Zhizhou Zhong committed on
Commit f0d6854 · unverified · 1 Parent(s): 03f211f

feat: data preprocessing and training (#294)


* docs: update readme

* docs: update readme

* feat: training codes

* feat: data preprocess

* docs: release training

.gitignore CHANGED
@@ -8,4 +8,8 @@ results/
8
  ./models
9
  **/__pycache__/
10
  *.py[cod]
11
- *$py.class
 
 
 
 
 
8
  ./models
9
  **/__pycache__/
10
  *.py[cod]
11
+ *$py.class
12
+ dataset/
13
+ ffmpeg*
14
+ debug
15
+ exp_out
README.md CHANGED
@@ -130,8 +130,9 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
130
  - [x] codes for real-time inference.
131
  - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
132
  - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
133
- - [x] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
134
- - [ ] training and dataloader code (Expected completion on 04/04/2025).
 
135
 
136
 
137
  # Getting Started
@@ -187,6 +188,7 @@ huggingface-cli download TMElyralab/MuseTalk --local-dir models/
187
  - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
188
  - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
189
  - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
 
190
 
191
 
192
  Finally, these weights should be organized in `models` as follows:
@@ -198,6 +200,8 @@ Finally, these weights should be organized in `models` as follows:
198
  ├── musetalkV15
199
  │ └── musetalk.json
200
  │ └── unet.pth
 
 
201
  ├── dwpose
202
  │ └── dw-ll_ucoco_384.pth
203
  ├── face-parse-bisent
@@ -265,6 +269,73 @@ For faster generation without saving images, you can use:
265
  python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
266
  ```
267

268
  ## TestCases For 1.0
269
  <table class="center">
270
  <tr style="font-weight: bolder;text-align:center;">
@@ -368,7 +439,7 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
368
  As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
369
 
370
  # Acknowledgement
371
- 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
372
  1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
373
  1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
374
 
 
130
  - [x] codes for real-time inference.
131
  - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
132
  - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
133
+ - [x] real-time inference code for version 1.5.
134
+ - [x] training and data preprocessing code.
135
+ - [ ] we **always** welcome issues and PRs to improve this repository! 😊
136
 
137
 
138
  # Getting Started
 
188
  - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
189
  - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
190
  - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
191
+ - [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
192
 
193
 
194
  Finally, these weights should be organized in `models` as follows:
 
200
  ├── musetalkV15
201
  │ └── musetalk.json
202
  │ └── unet.pth
203
+ ├── syncnet
204
+ │ └── latentsync_syncnet.pt
205
  ├── dwpose
206
  │ └── dw-ll_ucoco_384.pth
207
  ├── face-parse-bisent
 
269
  python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
270
  ```
271
 
272
+ ## Training
273
+
274
+ ### Data Preparation
275
+ To train MuseTalk, you need to prepare your dataset following these steps:
276
+
277
+ 1. **Place your source videos**
278
+
279
+ For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
280
+
281
+ 2. **Run the preprocessing script**
282
+ ```bash
283
+ python -m scripts.preprocess --config ./configs/training/preprocess.yaml
284
+ ```
285
+ This script will:
286
+ - Extract frames from videos
287
+ - Detect and align faces
288
+ - Generate audio features
289
+ - Create the necessary data structure for training (a quick sanity check for these outputs is sketched below)
290
+
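To confirm that preprocessing produced what training expects, a minimal sanity check is sketched below. It assumes the default paths from `configs/training/preprocess.yaml` and only inspects the fields that `musetalk/data/dataset.py` reads from each meta file; it is an editor's sketch, not part of the repository.

```python
import json
import os

meta_root = "./dataset/HDTF/meta"        # meta_root in preprocess.yaml
train_list = "./dataset/HDTF/train.txt"  # video_clip_file_list_train in preprocess.yaml

with open(train_list) as f:
    f.readline()                          # the dataloader skips a header line, so do the same
    first_entry = f.readline().split()[0]

with open(os.path.join(meta_root, first_entry)) as f:
    meta = json.load(f)

# Fields consumed by FaceDataset.__getitem__ in musetalk/data/dataset.py
for key in ("mp4_path", "wav_path", "face_list", "landmark_list", "frames"):
    print(f"{key}: {'ok' if key in meta else 'MISSING'}")
print("frames:", meta.get("frames"))
```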
291
+ ### Training Process
292
+ After data preprocessing, you can start the training process:
293
+
294
+ 1. **First Stage**
295
+ ```bash
296
+ sh train.sh stage1
297
+ ```
298
+
299
+ 2. **Second Stage**
300
+ ```bash
301
+ sh train.sh stage2
302
+ ```
303
+
304
+ ### Configuration Adjustment
305
+ Before starting training, adjust the configuration files to match your hardware and requirements:
306
+
307
+ 1. **GPU Configuration** (`configs/training/gpu.yaml`):
308
+ - `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
309
+ - `num_processes`: Set this to match the number of GPUs you're using
310
+
311
+ 2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
312
+ - `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
313
+ - `data.n_sample_frames`: Number of sampled frames per video (default: 1)
314
+
315
+ 3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
316
+ - `random_init_unet`: Must be set to `False` to use the model from stage 1
317
+ - `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
318
+ - `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
319
+ - `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8); see the arithmetic sketch below
320
+
321
+
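As a rough reference for how these settings combine, the sketch below works out the per-step numbers for the stage 2 defaults together with `num_processes` from `configs/training/gpu.yaml`. It is back-of-the-envelope arithmetic by the editor, not repository code.

```python
# Stage 2 defaults (stage2.yaml) and the GPU count from gpu.yaml
train_bs = 2          # data.train_bs
n_sample_frames = 16  # data.n_sample_frames
grad_accum_steps = 8  # solver.gradient_accumulation_steps
num_gpus = 2          # num_processes in gpu.yaml

frames_per_forward = train_bs * n_sample_frames            # 32 frames per forward pass on each GPU
clips_per_update = train_bs * num_gpus * grad_accum_steps  # 32 video clips contribute to each optimizer step
print(frames_per_forward, clips_per_update)
```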
322
+ ### GPU Memory Requirements
323
+ Based on our testing on a machine with 8 NVIDIA H20 GPUs:
324
+
325
+ #### Stage 1 Memory Usage
326
+ | Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
327
+ |:----------:|:----------------------:|:--------------:|:--------------:|
328
+ | 8 | 1 | ~32GB | |
329
+ | 16 | 1 | ~45GB | |
330
+ | 32 | 1 | ~74GB | ✓ |
331
+
332
+ #### Stage 2 Memory Usage
333
+ | Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
334
+ |:----------:|:----------------------:|:--------------:|:--------------:|
335
+ | 1 | 8 | ~54GB | |
336
+ | 2 | 2 | ~80GB | |
337
+ | 2 | 8 | ~85GB | ✓ |
338
+
339
  ## TestCases For 1.0
340
  <table class="center">
341
  <tr style="font-weight: bolder;text-align:center;">
 
439
  As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
440
 
441
  # Acknowledgement
442
+ 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
443
  1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
444
  1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
445
 
configs/training/gpu.yaml ADDED
@@ -0,0 +1,21 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: True
3
+ deepspeed_config:
4
+ offload_optimizer_device: none
5
+ offload_param_device: none
6
+ zero3_init_flag: False
7
+ zero_stage: 2
8
+
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ gpu_ids: "5, 7" # set this to the GPU IDs you want to use
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ num_machines: 1
15
+ num_processes: 2 # must equal the number of GPUs listed in gpu_ids
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
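The two commented fields above have to stay in sync: `num_processes` should equal the number of entries in `gpu_ids`. A small optional check, assuming PyYAML is installed; this is an editor's sketch, not part of the repository.

```python
import yaml

with open("configs/training/gpu.yaml") as f:
    cfg = yaml.safe_load(f)

gpu_count = len(str(cfg["gpu_ids"]).split(","))
if cfg["num_processes"] != gpu_count:
    raise ValueError(f"num_processes={cfg['num_processes']} but gpu_ids lists {gpu_count} GPUs")
print("gpu.yaml is self-consistent")
```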
configs/training/preprocess.yaml ADDED
@@ -0,0 +1,31 @@
1
+ clip_len_second: 30 # the length of the video clip
2
+ video_root_raw: "./dataset/HDTF/source/" # the path of the original video
3
+ val_list_hdtf:
4
+ - RD_Radio7_000
5
+ - RD_Radio8_000
6
+ - RD_Radio9_000
7
+ - WDA_TinaSmith_000
8
+ - WDA_TomCarper_000
9
+ - WDA_TomPerez_000
10
+ - WDA_TomUdall_000
11
+ - WDA_VeronicaEscobar0_000
12
+ - WDA_VeronicaEscobar1_000
13
+ - WDA_WhipJimClyburn_000
14
+ - WDA_XavierBecerra_000
15
+ - WDA_XavierBecerra_001
16
+ - WDA_XavierBecerra_002
17
+ - WDA_ZoeLofgren_000
18
+ - WRA_SteveScalise1_000
19
+ - WRA_TimScott_000
20
+ - WRA_ToddYoung_000
21
+ - WRA_TomCotton_000
22
+ - WRA_TomPrice_000
23
+ - WRA_VickyHartzler_000
24
+
25
+ # following dir will be automatically generated
26
+ video_root_25fps: "./dataset/HDTF/video_root_25fps/"
27
+ video_file_list: "./dataset/HDTF/video_file_list.txt"
28
+ video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
29
+ meta_root: "./dataset/HDTF/meta/"
30
+ video_clip_file_list_train: "./dataset/HDTF/train.txt"
31
+ video_clip_file_list_val: "./dataset/HDTF/val.txt"
configs/training/stage1.yaml ADDED
@@ -0,0 +1,89 @@
1
+ exp_name: 'test' # Name of the experiment
2
+ output_dir: './exp_out/stage1/' # Directory to save experiment outputs
3
+ unet_sub_folder: musetalk # Subfolder name for UNet model
4
+ random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
5
+ whisper_path: "./models/whisper" # Path to the Whisper model
6
+ pretrained_model_name_or_path: "./models" # Path to pretrained models
7
+ resume_from_checkpoint: True # Whether to resume training from a checkpoint
8
+ padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
9
+ vae_type: "sd-vae" # Type of VAE model to use
10
+ # Validation parameters
11
+ num_images_to_keep: 8 # Number of validation images to keep
12
+ ref_dropout_rate: 0 # Dropout rate for reference images
13
+ syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
14
+ use_adapted_weight: False # Whether to use adapted weights for loss calculation
15
+ cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
16
+ cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
17
+ crop_type: "crop_resize" # Type of cropping method
18
+ random_margin_method: "normal" # Method for random margin generation
19
+ num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
20
+
21
+ data:
22
+ dataset_key: "HDTF" # Dataset to use for training
23
+ train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
24
+ image_size: 256 # Size of input images
25
+ n_sample_frames: 1 # Number of frames to sample per batch
26
+ num_workers: 8 # Number of data loading workers
27
+ audio_padding_length_left: 2 # Left padding length for audio features
28
+ audio_padding_length_right: 2 # Right padding length for audio features
29
+ sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
30
+ top_k_ratio: 0.51 # Ratio for top-k sampling
31
+ contorl_face_min_size: True # Whether to control minimum face size
32
+ min_face_size: 150 # Minimum face size in pixels
33
+
34
+ loss_params:
35
+ l1_loss: 1.0 # Weight for L1 loss
36
+ vgg_loss: 0.01 # Weight for VGG perceptual loss
37
+ vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
38
+ pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
39
+ gan_loss: 0 # Weight for GAN loss
40
+ fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
41
+ sync_loss: 0 # Weight for sync loss
42
+ mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
43
+
44
+ model_params:
45
+ discriminator_params:
46
+ scales: [1] # Scales for discriminator
47
+ block_expansion: 32 # Expansion factor for discriminator blocks
48
+ max_features: 512 # Maximum number of features in discriminator
49
+ num_blocks: 4 # Number of blocks in discriminator
50
+ sn: True # Whether to use spectral normalization
51
+ image_channel: 3 # Number of image channels
52
+ estimate_jacobian: False # Whether to estimate Jacobian
53
+
54
+ discriminator_train_params:
55
+ lr: 0.000005 # Learning rate for discriminator
56
+ eps: 0.00000001 # Epsilon for optimizer
57
+ weight_decay: 0.01 # Weight decay for optimizer
58
+ patch_size: 1 # Size of patches for discriminator
59
+ betas: [0.5, 0.999] # Beta parameters for Adam optimizer
60
+ epochs: 10000 # Number of training epochs
61
+ start_gan: 1000 # Step to start GAN training
62
+
63
+ solver:
64
+ gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
65
+ uncond_steps: 10 # Number of unconditional steps
66
+ mixed_precision: 'fp32' # Precision mode for training
67
+ enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
68
+ gradient_checkpointing: True # Whether to use gradient checkpointing
69
+ max_train_steps: 250000 # Maximum number of training steps
70
+ max_grad_norm: 1.0 # Maximum gradient norm for clipping
71
+ # Learning rate parameters
72
+ learning_rate: 2.0e-5 # Base learning rate
73
+ scale_lr: False # Whether to scale learning rate
74
+ lr_warmup_steps: 1000 # Number of warmup steps for learning rate
75
+ lr_scheduler: "linear" # Type of learning rate scheduler
76
+ # Optimizer parameters
77
+ use_8bit_adam: False # Whether to use 8-bit Adam optimizer
78
+ adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
79
+ adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
80
+ adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
81
+ adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
82
+
83
+ total_limit: 10 # Maximum number of checkpoints to keep
84
+ save_model_epoch_interval: 250000 # Interval between model saves
85
+ checkpointing_steps: 10000 # Number of steps between checkpoints
86
+ val_freq: 2000 # Frequency of validation
87
+
88
+ seed: 41 # Random seed for reproducibility
89
+
configs/training/stage2.yaml ADDED
@@ -0,0 +1,89 @@
1
+ exp_name: 'test' # Name of the experiment
2
+ output_dir: './exp_out/stage2/' # Directory to save experiment outputs
3
+ unet_sub_folder: musetalk # Subfolder name for UNet model
4
+ random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
5
+ whisper_path: "./models/whisper" # Path to the Whisper model
6
+ pretrained_model_name_or_path: "./models" # Path to pretrained models
7
+ resume_from_checkpoint: True # Whether to resume training from a checkpoint
8
+ padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
9
+ vae_type: "sd-vae" # Type of VAE model to use
10
+ # Validation parameters
11
+ num_images_to_keep: 8 # Number of validation images to keep
12
+ ref_dropout_rate: 0 # Dropout rate for reference images
13
+ syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
14
+ use_adapted_weight: False # Whether to use adapted weights for loss calculation
15
+ cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
16
+ cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
17
+ crop_type: "dynamic_margin_crop_resize" # Type of cropping method
18
+ random_margin_method: "normal" # Method for random margin generation
19
+ num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
20
+
21
+ data:
22
+ dataset_key: "HDTF" # Dataset to use for training
23
+ train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
24
+ image_size: 256 # Size of input images
25
+ n_sample_frames: 16 # Number of frames to sample per batch
26
+ num_workers: 8 # Number of data loading workers
27
+ audio_padding_length_left: 2 # Left padding length for audio features
28
+ audio_padding_length_right: 2 # Right padding length for audio features
29
+ sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
30
+ top_k_ratio: 0.51 # Ratio for top-k sampling
31
+ contorl_face_min_size: True # Whether to control minimum face size
32
+ min_face_size: 200 # Minimum face size in pixels
33
+
34
+ loss_params:
35
+ l1_loss: 1.0 # Weight for L1 loss
36
+ vgg_loss: 0.01 # Weight for VGG perceptual loss
37
+ vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
38
+ pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
39
+ gan_loss: 0.01 # Weight for GAN loss
40
+ fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
41
+ sync_loss: 0.05 # Weight for sync loss
42
+ mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
43
+
44
+ model_params:
45
+ discriminator_params:
46
+ scales: [1] # Scales for discriminator
47
+ block_expansion: 32 # Expansion factor for discriminator blocks
48
+ max_features: 512 # Maximum number of features in discriminator
49
+ num_blocks: 4 # Number of blocks in discriminator
50
+ sn: True # Whether to use spectral normalization
51
+ image_channel: 3 # Number of image channels
52
+ estimate_jacobian: False # Whether to estimate Jacobian
53
+
54
+ discriminator_train_params:
55
+ lr: 0.000005 # Learning rate for discriminator
56
+ eps: 0.00000001 # Epsilon for optimizer
57
+ weight_decay: 0.01 # Weight decay for optimizer
58
+ patch_size: 1 # Size of patches for discriminator
59
+ betas: [0.5, 0.999] # Beta parameters for Adam optimizer
60
+ epochs: 10000 # Number of training epochs
61
+ start_gan: 1000 # Step to start GAN training
62
+
63
+ solver:
64
+ gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
65
+ uncond_steps: 10 # Number of unconditional steps
66
+ mixed_precision: 'fp32' # Precision mode for training
67
+ enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
68
+ gradient_checkpointing: True # Whether to use gradient checkpointing
69
+ max_train_steps: 250000 # Maximum number of training steps
70
+ max_grad_norm: 1.0 # Maximum gradient norm for clipping
71
+ # Learning rate parameters
72
+ learning_rate: 5.0e-6 # Base learning rate
73
+ scale_lr: False # Whether to scale learning rate
74
+ lr_warmup_steps: 1000 # Number of warmup steps for learning rate
75
+ lr_scheduler: "linear" # Type of learning rate scheduler
76
+ # Optimizer parameters
77
+ use_8bit_adam: False # Whether to use 8-bit Adam optimizer
78
+ adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
79
+ adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
80
+ adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
81
+ adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
82
+
83
+ total_limit: 10 # Maximum number of checkpoints to keep
84
+ save_model_epoch_interval: 250000 # Interval between model saves
85
+ checkpointing_steps: 2000 # Number of steps between checkpoints
86
+ val_freq: 2000 # Frequency of validation
87
+
88
+ seed: 41 # Random seed for reproducibility
89
+
configs/training/syncnet.yaml ADDED
@@ -0,0 +1,19 @@
1
+ # This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
2
+ model:
3
+ audio_encoder: # input (1, 80, 52)
4
+ in_channels: 1
5
+ block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
6
+ downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
7
+ attn_blocks: [0, 0, 0, 0, 0, 0, 0]
8
+ dropout: 0.0
9
+ visual_encoder: # input (48, 128, 256)
10
+ in_channels: 48
11
+ block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
12
+ downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
13
+ attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
14
+ dropout: 0.0
15
+
16
+ ckpt:
17
+ resume_ckpt_path: ""
18
+ inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
19
+ save_ckpt_steps: 2500
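The shape comments above line up with numbers used elsewhere in this commit: the audio input length 52 equals `syncnet_mel_step_size` in `musetalk/data/dataset.py`, and the 48 visual input channels presumably come from stacking 16 RGB frames along the channel axis (the 128×256 spatial size would then be the lower half of a 256×256 face crop, as in LatentSync's pixel-space SyncNet). The editor's arithmetic check below rests on those assumptions.

```python
import math

fps = 25                 # videos are resampled to 25 fps during preprocessing
num_frames = 16          # frames per SyncNet clip (n_sample_frames / num_backward_frames in stage2.yaml)
mel_frames_per_sec = 80  # audio.melspectrogram: 16 kHz sample rate / hop_size 200

mel_steps = math.ceil(num_frames / fps * mel_frames_per_sec)  # ceil(51.2) = 52 -> audio input (1, 80, 52)
visual_channels = num_frames * 3                              # assumed: 16 RGB frames stacked -> 48 channels
print(mel_steps, visual_channels)
```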
inference.sh CHANGED
@@ -59,7 +59,7 @@ cmd_args="--inference_config $config_path \
59
  --result_dir $result_dir \
60
  --unet_model_path $unet_model_path \
61
  --unet_config $unet_config \
62
- --version $version_ar"
63
 
64
  # Add realtime-specific arguments if in realtime mode
65
  if [ "$mode" = "realtime" ]; then
 
59
  --result_dir $result_dir \
60
  --unet_model_path $unet_model_path \
61
  --unet_config $unet_config \
62
+ --version $version_arg"
63
 
64
  # Add realtime-specific arguments if in realtime mode
65
  if [ "$mode" = "realtime" ]; then
musetalk/data/audio.py ADDED
@@ -0,0 +1,168 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+
7
+ class HParams:
8
+ # copy from wav2lip
9
+ def __init__(self):
10
+ self.n_fft = 800
11
+ self.hop_size = 200
12
+ self.win_size = 800
13
+ self.sample_rate = 16000
14
+ self.frame_shift_ms = None
15
+ self.signal_normalization = True
16
+
17
+ self.allow_clipping_in_normalization = True
18
+ self.symmetric_mels = True
19
+ self.max_abs_value = 4.0
20
+ self.preemphasize = True
21
+ self.preemphasis = 0.97
22
+ self.min_level_db = -100
23
+ self.ref_level_db = 20
24
+ self.fmin = 55
25
+ self.fmax=7600
26
+
27
+ self.use_lws=False
28
+ self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
29
+ self.rescale=True # Whether to rescale audio prior to preprocessing
30
+ self.rescaling_max=0.9 # Rescaling value
31
+ self.use_lws=False
32
+
33
+
34
+ hp = HParams()
35
+
36
+ def load_wav(path, sr):
37
+ return librosa.core.load(path, sr=sr)[0]
38
+ #def load_wav(path, sr):
39
+ # audio, sr_native = sf.read(path)
40
+ # if sr != sr_native:
41
+ # audio = librosa.resample(audio.T, sr_native, sr).T
42
+ # return audio
43
+
44
+ def save_wav(wav, path, sr):
45
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
46
+ #proposed by @dsmiller
47
+ wavfile.write(path, sr, wav.astype(np.int16))
48
+
49
+ def save_wavenet_wav(wav, path, sr):
50
+ librosa.output.write_wav(path, wav, sr=sr)
51
+
52
+ def preemphasis(wav, k, preemphasize=True):
53
+ if preemphasize:
54
+ return signal.lfilter([1, -k], [1], wav)
55
+ return wav
56
+
57
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
58
+ if inv_preemphasize:
59
+ return signal.lfilter([1], [1, -k], wav)
60
+ return wav
61
+
62
+ def get_hop_size():
63
+ hop_size = hp.hop_size
64
+ if hop_size is None:
65
+ assert hp.frame_shift_ms is not None
66
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
67
+ return hop_size
68
+
69
+ def linearspectrogram(wav):
70
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
71
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
72
+
73
+ if hp.signal_normalization:
74
+ return _normalize(S)
75
+ return S
76
+
77
+ def melspectrogram(wav):
78
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
79
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
80
+
81
+ if hp.signal_normalization:
82
+ return _normalize(S)
83
+ return S
84
+
85
+ def _lws_processor():
86
+ import lws
87
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
88
+
89
+ def _stft(y):
90
+ if hp.use_lws:
91
+ return _lws_processor(hp).stft(y).T
92
+ else:
93
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
94
+
95
+ ##########################################################
96
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
97
+ def num_frames(length, fsize, fshift):
98
+ """Compute number of time frames of spectrogram
99
+ """
100
+ pad = (fsize - fshift)
101
+ if length % fshift == 0:
102
+ M = (length + pad * 2 - fsize) // fshift + 1
103
+ else:
104
+ M = (length + pad * 2 - fsize) // fshift + 2
105
+ return M
106
+
107
+
108
+ def pad_lr(x, fsize, fshift):
109
+ """Compute left and right padding
110
+ """
111
+ M = num_frames(len(x), fsize, fshift)
112
+ pad = (fsize - fshift)
113
+ T = len(x) + 2 * pad
114
+ r = (M - 1) * fshift + fsize - T
115
+ return pad, pad + r
116
+ ##########################################################
117
+ #Librosa correct padding
118
+ def librosa_pad_lr(x, fsize, fshift):
119
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
120
+
121
+ # Conversions
122
+ _mel_basis = None
123
+
124
+ def _linear_to_mel(spectogram):
125
+ global _mel_basis
126
+ if _mel_basis is None:
127
+ _mel_basis = _build_mel_basis()
128
+ return np.dot(_mel_basis, spectogram)
129
+
130
+ def _build_mel_basis():
131
+ assert hp.fmax <= hp.sample_rate // 2
132
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
133
+ fmin=hp.fmin, fmax=hp.fmax)
134
+
135
+ def _amp_to_db(x):
136
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
137
+ return 20 * np.log10(np.maximum(min_level, x))
138
+
139
+ def _db_to_amp(x):
140
+ return np.power(10.0, (x) * 0.05)
141
+
142
+ def _normalize(S):
143
+ if hp.allow_clipping_in_normalization:
144
+ if hp.symmetric_mels:
145
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
146
+ -hp.max_abs_value, hp.max_abs_value)
147
+ else:
148
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
149
+
150
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
151
+ if hp.symmetric_mels:
152
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
153
+ else:
154
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
155
+
156
+ def _denormalize(D):
157
+ if hp.allow_clipping_in_normalization:
158
+ if hp.symmetric_mels:
159
+ return (((np.clip(D, -hp.max_abs_value,
160
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
161
+ + hp.min_level_db)
162
+ else:
163
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
164
+
165
+ if hp.symmetric_mels:
166
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
167
+ else:
168
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
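With the wav2lip-style defaults above (16 kHz sample rate, hop_size 200, num_mels 80), `melspectrogram` yields 80 mel frames per second, which is why `crop_audio_window` in `musetalk/data/dataset.py` converts video frame indices with the factor `80. * (start_frame_num / fps)`. A quick check on a synthetic signal, assuming the module is importable from the repository root (editor's sketch):

```python
import numpy as np
from musetalk.data import audio

sr = 16000
t = np.linspace(0, 2.0, 2 * sr, endpoint=False)               # 2 seconds of audio
wav = 0.1 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)  # a 440 Hz tone

mel = audio.melspectrogram(wav)
print(mel.shape)  # (80, 161): 80 mel bins, ~80 frames per second (16000 / hop_size 200)
```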
musetalk/data/dataset.py ADDED
@@ -0,0 +1,607 @@
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, ConcatDataset
7
+ import torchvision.transforms as transforms
8
+ from transformers import AutoFeatureExtractor
9
+ import librosa
10
+ import time
11
+ import json
12
+ import math
13
+ from decord import AudioReader, VideoReader
14
+ from decord.ndarray import cpu
15
+
16
+ from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
17
+ from musetalk.data import audio
18
+
19
+ syncnet_mel_step_size = math.ceil(16 / 5 * 16) # = ceil(16 frames / 25 fps * 80 mel frames per second) = 52, following LatentSync
20
+
21
+
22
+ class FaceDataset(Dataset):
23
+ """Dataset class for loading and processing video data
24
+
25
+ Each video can be represented as:
26
+ - Concatenated frame images
27
+ - '.mp4' or '.gif' files
28
+ - Folder containing all frames
29
+ """
30
+ def __init__(self,
31
+ cfg,
32
+ list_paths,
33
+ root_path='./dataset/',
34
+ repeats=None):
35
+ # Initialize dataset paths
36
+ meta_paths = []
37
+ if repeats is None:
38
+ repeats = [1] * len(list_paths)
39
+ assert len(repeats) == len(list_paths)
40
+
41
+ # Load data list
42
+ for list_path, repeat_time in zip(list_paths, repeats):
43
+ with open(list_path, 'r') as f:
44
+ num = 0
45
+ f.readline() # Skip header line
46
+ for line in f.readlines():
47
+ line_info = line.strip()
48
+ meta = line_info.split()
49
+ meta = meta[0]
50
+ meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
51
+ num += 1
52
+ print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
53
+
54
+ # Set basic attributes
55
+ self.meta_paths = meta_paths
56
+ self.root_path = root_path
57
+ self.image_size = cfg['image_size']
58
+ self.min_face_size = cfg['min_face_size']
59
+ self.T = cfg['T']
60
+ self.sample_method = cfg['sample_method']
61
+ self.top_k_ratio = cfg['top_k_ratio']
62
+ self.max_attempts = 200
63
+ self.padding_pixel_mouth = cfg['padding_pixel_mouth']
64
+
65
+ # Cropping related parameters
66
+ self.crop_type = cfg['crop_type']
67
+ self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
68
+ self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
69
+ self.random_margin_method = cfg['random_margin_method']
70
+
71
+ # Image transformations
72
+ self.to_tensor = transforms.Compose([
73
+ transforms.ToTensor(),
74
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
75
+ ])
76
+ self.pose_to_tensor = transforms.Compose([
77
+ transforms.ToTensor(),
78
+ ])
79
+
80
+ # Feature extractor
81
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
82
+ self.contorl_face_min_size = cfg["contorl_face_min_size"]
83
+
84
+ print("The sample method is: ", self.sample_method)
85
+ print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
86
+
87
+ def generate_random_value(self):
88
+ """Generate random value
89
+
90
+ Returns:
91
+ float: Generated random value
92
+ """
93
+ if self.random_margin_method == "uniform":
94
+ random_value = np.random.uniform(
95
+ self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
96
+ self.jaw2edge_margin_mean + self.jaw2edge_margin_std
97
+ )
98
+ elif self.random_margin_method == "normal":
99
+ random_value = np.random.normal(
100
+ loc=self.jaw2edge_margin_mean,
101
+ scale=self.jaw2edge_margin_std
102
+ )
103
+ random_value = np.clip(
104
+ random_value,
105
+ self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
106
+ self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
107
+ )
108
+ else:
109
+ raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
110
+ return max(0, random_value)
111
+
112
+ def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
113
+ """Dynamically crop image with dynamic margin
114
+
115
+ Args:
116
+ img: Input image
117
+ original_bbox: Original bounding box
118
+ extra_margin: Extra margin
119
+
120
+ Returns:
121
+ tuple: (x1, y1, x2, y2, extra_margin)
122
+ """
123
+ if extra_margin is None:
124
+ extra_margin = self.generate_random_value()
125
+ w, h = img.size
126
+ x1, y1, x2, y2 = original_bbox
127
+ y2 = min(y2 + int(extra_margin), h)
128
+ return x1, y1, x2, y2, extra_margin
129
+
130
+ def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
131
+ """Crop and resize image
132
+
133
+ Args:
134
+ img: Input image
135
+ bbox: Bounding box
136
+ crop_type: Type of cropping
137
+ extra_margin: Extra margin
138
+
139
+ Returns:
140
+ tuple: (Processed image, extra_margin, mask_scaled_factor)
141
+ """
142
+ mask_scaled_factor = 1.
143
+ if crop_type == 'crop_resize':
144
+ x1, y1, x2, y2 = bbox
145
+ img = img.crop((x1, y1, x2, y2))
146
+ img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
147
+ elif crop_type == 'dynamic_margin_crop_resize':
148
+ x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
149
+ w_original, _ = img.size
150
+ img = img.crop((x1, y1, x2, y2))
151
+ w_cropped, _ = img.size
152
+ mask_scaled_factor = w_cropped / w_original
153
+ img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
154
+ elif crop_type == 'resize':
155
+ w, h = img.size
156
+ scale = np.sqrt(self.image_size ** 2 / (h * w))
157
+ new_w = int(w * scale) // 64 * 64 # round down to a multiple of 64 (PIL resize expects integers)
158
+ new_h = int(h * scale) // 64 * 64 # round down to a multiple of 64 (PIL resize expects integers)
159
+ img = img.resize((new_w, new_h), Image.LANCZOS)
160
+ return img, extra_margin, mask_scaled_factor
161
+
162
+ def get_audio_file(self, wav_path, start_index):
163
+ """Get audio file features
164
+
165
+ Args:
166
+ wav_path: Audio file path
167
+ start_index: Starting index
168
+
169
+ Returns:
170
+ tuple: (Audio features, start index)
171
+ """
172
+ if not os.path.exists(wav_path):
173
+ return None
174
+ audio_input_librosa, sampling_rate = librosa.load(wav_path, sr=16000)
175
+ assert sampling_rate == 16000
176
+
177
+ while start_index >= 25 * 30: # Whisper features cover 30 s windows: 25 fps * 30 s = 750 frames per window
178
+ audio_input = audio_input_librosa[16000*30:]
179
+ start_index -= 25 * 30
180
+ if start_index + 2 * 25 >= 25 * 30:
181
+ start_index -= 4 * 25
182
+ audio_input = audio_input_librosa[16000*4:16000*34]
183
+ else:
184
+ audio_input = audio_input_librosa[:16000*30]
185
+
186
+ assert 2 * (start_index) >= 0
187
+ assert 2 * (start_index + 2 * 25) <= 1500
188
+
189
+ audio_input = self.feature_extractor(
190
+ audio_input,
191
+ return_tensors="pt",
192
+ sampling_rate=sampling_rate
193
+ ).input_features
194
+ return audio_input, start_index
195
+
196
+ def get_audio_file_mel(self, wav_path, start_index):
197
+ """Get mel spectrogram of audio file
198
+
199
+ Args:
200
+ wav_path: Audio file path
201
+ start_index: Starting index
202
+
203
+ Returns:
204
+ tuple: (Mel spectrogram, start index)
205
+ """
206
+ if not os.path.exists(wav_path):
207
+ return None
208
+
209
+ audio_input, sampling_rate = librosa.load(wav_path, sr=16000)
210
+ assert sampling_rate == 16000
211
+
212
+ audio_input = self.mel_feature_extractor(audio_input)
213
+ return audio_input, start_index
214
+
215
+ def mel_feature_extractor(self, audio_input):
216
+ """Extract mel spectrogram features
217
+
218
+ Args:
219
+ audio_input: Input audio
220
+
221
+ Returns:
222
+ ndarray: Mel spectrogram features
223
+ """
224
+ orig_mel = audio.melspectrogram(audio_input)
225
+ return orig_mel.T
226
+
227
+ def crop_audio_window(self, spec, start_frame_num, fps=25):
228
+ """Crop audio window
229
+
230
+ Args:
231
+ spec: Spectrogram
232
+ start_frame_num: Starting frame number
233
+ fps: Frames per second
234
+
235
+ Returns:
236
+ ndarray: Cropped spectrogram
237
+ """
238
+ start_idx = int(80. * (start_frame_num / float(fps)))
239
+ end_idx = start_idx + syncnet_mel_step_size
240
+ return spec[start_idx: end_idx, :]
241
+
242
+ def get_syncnet_input(self, video_path):
243
+ """Get SyncNet input features
244
+
245
+ Args:
246
+ video_path: Video file path
247
+
248
+ Returns:
249
+ ndarray: SyncNet input features
250
+ """
251
+ ar = AudioReader(video_path, sample_rate=16000)
252
+ original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
253
+ return original_mel.T
254
+
255
+ def get_resized_mouth_mask(
256
+ self,
257
+ img_resized,
258
+ landmark_array,
259
+ face_shape,
260
+ padding_pixel_mouth=0,
261
+ image_size=256,
262
+ crop_margin=0
263
+ ):
264
+ landmark_array = np.array(landmark_array)
265
+ resized_landmark = resize_landmark(
266
+ landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
267
+
268
+ landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
269
+ min_x, min_y = np.min(landmark_array, axis=0)
270
+ max_x, max_y = np.max(landmark_array, axis=0)
271
+ min_x = min_x - padding_pixel_mouth
272
+ max_x = max_x + padding_pixel_mouth
273
+
274
+ # Calculate x-axis length and use it for y-axis
275
+ width = max_x - min_x
276
+
277
+ # Calculate old center point
278
+ center_y = (max_y + min_y) / 2
279
+
280
+ # Determine new min_y and max_y based on width
281
+ min_y = center_y - width / 4
282
+ max_y = center_y + width / 4
283
+
284
+ # Adjust mask position for dynamic crop, shift y-axis
285
+ min_y = min_y - crop_margin
286
+ max_y = max_y - crop_margin
287
+
288
+ # Prevent out of bounds
289
+ min_x = max(min_x, 0)
290
+ min_y = max(min_y, 0)
291
+ max_x = min(max_x, face_shape[0])
292
+ max_y = min(max_y, face_shape[1])
293
+
294
+ mask = np.zeros_like(np.array(img_resized))
295
+ mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
296
+ return Image.fromarray(mask)
297
+
298
+ def __len__(self):
299
+ return 100000
300
+
301
+ def __getitem__(self, idx):
302
+ attempts = 0
303
+ while attempts < self.max_attempts:
304
+ try:
305
+ meta_path = random.sample(self.meta_paths, k=1)[0]
306
+ with open(meta_path, 'r') as f:
307
+ meta_data = json.load(f)
308
+ except Exception as e:
309
+ print(f"meta file error:{meta_path}")
310
+ print(e)
311
+ attempts += 1
312
+ time.sleep(0.1)
313
+ continue
314
+
315
+ video_path = meta_data["mp4_path"]
316
+ wav_path = meta_data["wav_path"]
317
+ bbox_list = meta_data["face_list"]
318
+ landmark_list = meta_data["landmark_list"]
319
+ T = self.T
320
+
321
+ s = 0
322
+ e = meta_data["frames"]
323
+ len_valid_clip = e - s
324
+
325
+ if len_valid_clip < T * 10:
326
+ attempts += 1
327
+ print(f"video {video_path} has less than {T * 10} frames")
328
+ continue
329
+
330
+ try:
331
+ cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
332
+ total_frames = len(cap)
333
+ assert total_frames == len(landmark_list)
334
+ assert total_frames == len(bbox_list)
335
+ landmark_shape = np.array(landmark_list).shape
336
+ if landmark_shape != (total_frames, 68, 2):
337
+ attempts += 1
338
+ print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
339
+ continue
340
+ except Exception as e:
341
+ print(f"video file error:{video_path}")
342
+ print(e)
343
+ attempts += 1
344
+ time.sleep(0.1)
345
+ continue
346
+
347
+ shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
348
+ landmark_list,
349
+ bbox_list
350
+ )
351
+ if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
352
+ print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
353
+ attempts += 1
354
+ continue
355
+
356
+ step = 1
357
+ drive_idx_start = random.randint(s, e - T * step)
358
+ drive_idx_list = list(
359
+ range(drive_idx_start, drive_idx_start + T * step, step))
360
+ assert len(drive_idx_list) == T
361
+
362
+ src_idx_list = []
363
+ list_index_out_of_range = False
364
+ for drive_idx in drive_idx_list:
365
+ src_idx = get_src_idx(
366
+ drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
367
+ if src_idx is None:
368
+ list_index_out_of_range = True
369
+ break
370
+ src_idx = min(src_idx, e - 1)
371
+ src_idx = max(src_idx, s)
372
+ src_idx_list.append(src_idx)
373
+
374
+ if list_index_out_of_range:
375
+ attempts += 1
376
+ print(f"video {video_path} has invalid source index for drive frames")
377
+ continue
378
+
379
+ ref_face_valid_flag = True
380
+ extra_margin = self.generate_random_value()
381
+
382
+ # Get reference images
383
+ ref_imgs = []
384
+ for src_idx in src_idx_list:
385
+ imSrc = Image.fromarray(cap[src_idx].asnumpy())
386
+ bbox_s = bbox_list_union[src_idx]
387
+ imSrc, _, _ = self.crop_resize_img(
388
+ imSrc,
389
+ bbox_s,
390
+ self.crop_type,
391
+ extra_margin=None
392
+ )
393
+ if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
394
+ ref_face_valid_flag = False
395
+ break
396
+ ref_imgs.append(imSrc)
397
+
398
+ if not ref_face_valid_flag:
399
+ attempts += 1
400
+ print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
401
+ continue
402
+
403
+ # Get target images and masks
404
+ imSameIDs = []
405
+ bboxes = []
406
+ face_masks = []
407
+ face_mask_valid = True
408
+ target_face_valid_flag = True
409
+
410
+ for drive_idx in drive_idx_list:
411
+ imSameID = Image.fromarray(cap[drive_idx].asnumpy())
412
+ bbox_s = bbox_list_union[drive_idx]
413
+ imSameID, _ , mask_scaled_factor = self.crop_resize_img(
414
+ imSameID,
415
+ bbox_s,
416
+ self.crop_type,
417
+ extra_margin=extra_margin
418
+ )
419
+ if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
420
+ target_face_valid_flag = False
421
+ break
422
+ crop_margin = extra_margin * mask_scaled_factor
423
+ face_mask = self.get_resized_mouth_mask(
424
+ imSameID,
425
+ shift_landmarks[drive_idx],
426
+ face_shapes[drive_idx],
427
+ self.padding_pixel_mouth,
428
+ self.image_size,
429
+ crop_margin=crop_margin
430
+ )
431
+ if np.count_nonzero(face_mask) == 0:
432
+ face_mask_valid = False
433
+ break
434
+
435
+ if face_mask.size[1] == 0 or face_mask.size[0] == 0:
436
+ print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
437
+ face_mask_valid = False
438
+ break
439
+
440
+ imSameIDs.append(imSameID)
441
+ bboxes.append(bbox_s)
442
+ face_masks.append(face_mask)
443
+
444
+ if not face_mask_valid:
445
+ attempts += 1
446
+ print(f"video {video_path} has invalid face mask")
447
+ continue
448
+
449
+ if not target_face_valid_flag:
450
+ attempts += 1
451
+ print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
452
+ continue
453
+
454
+ # Process audio features
455
+ audio_offset = drive_idx_list[0]
456
+ audio_step = step
457
+ fps = 25.0 / step
458
+
459
+ try:
460
+ audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
461
+ _, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
462
+ audio_feature_mel = self.get_syncnet_input(video_path)
463
+ except Exception as e:
464
+ print(f"audio file error:{wav_path}")
465
+ print(e)
466
+ attempts += 1
467
+ time.sleep(0.1)
468
+ continue
469
+
470
+ mel = self.crop_audio_window(audio_feature_mel, audio_offset)
471
+ if mel.shape[0] != syncnet_mel_step_size:
472
+ attempts += 1
473
+ print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
474
+ continue
475
+
476
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
477
+
478
+ # Build sample dictionary
479
+ sample = dict(
480
+ pixel_values_vid=torch.stack(
481
+ [self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
482
+ pixel_values_ref_img=torch.stack(
483
+ [self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
484
+ pixel_values_face_mask=torch.stack(
485
+ [self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
486
+ audio_feature=audio_feature[0],
487
+ audio_offset=audio_offset,
488
+ audio_step=audio_step,
489
+ mel=mel,
490
+ wav_path=wav_path,
491
+ fps=fps,
492
+ )
493
+
494
+ return sample
495
+
496
+ raise ValueError("Unable to find a valid sample after maximum attempts.")
497
+
498
+ class HDTFDataset(FaceDataset):
499
+ """HDTF dataset class"""
500
+ def __init__(self, cfg):
501
+ root_path = './dataset/HDTF/meta'
502
+ list_paths = [
503
+ './dataset/HDTF/train.txt',
504
+ ]
505
+
506
+
507
+ repeats = [10]
508
+ super().__init__(cfg, list_paths, root_path, repeats)
509
+ print('HDTFDataset: ', len(self))
510
+
511
+ class VFHQDataset(FaceDataset):
512
+ """VFHQ dataset class"""
513
+ def __init__(self, cfg):
514
+ root_path = './dataset/VFHQ/meta'
515
+ list_paths = [
516
+ './dataset/VFHQ/train.txt',
517
+ ]
518
+ repeats = [1]
519
+ super().__init__(cfg, list_paths, root_path, repeats)
520
+ print('VFHQDataset: ', len(self))
521
+
522
+ def PortraitDataset(cfg=None):
523
+ """Return dataset based on configuration
524
+
525
+ Args:
526
+ cfg: Configuration dictionary
527
+
528
+ Returns:
529
+ Dataset: Combined dataset
530
+ """
531
+ if cfg["dataset_key"] == "HDTF":
532
+ return ConcatDataset([HDTFDataset(cfg)])
533
+ elif cfg["dataset_key"] == "VFHQ":
534
+ return ConcatDataset([VFHQDataset(cfg)])
535
+ else:
536
+ print("############ use all dataset ############ ")
537
+ return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
538
+
539
+
540
+ if __name__ == '__main__':
541
+ # Set random seeds for reproducibility
542
+ seed = 42
543
+ random.seed(seed)
544
+ np.random.seed(seed)
545
+ torch.manual_seed(seed)
546
+ torch.cuda.manual_seed(seed)
547
+ torch.cuda.manual_seed_all(seed)
548
+
549
+ # Create dataset with configuration parameters
550
+ dataset = PortraitDataset(cfg={
551
+ 'T': 1, # Number of frames to process at once
552
+ 'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
553
+ 'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
554
+ 'image_size': 256, # Size of processed images (height and width)
555
+ 'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
556
+ 'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
557
+ 'contorl_face_min_size': True, # Whether to enforce minimum face size
558
+ 'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
559
+ 'min_face_size': 200, # Minimum face size requirement for dataset
560
+ 'whisper_path': "./models/whisper", # Path to Whisper model
561
+ 'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
562
+ 'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
563
+ 'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
564
+ })
565
+ print(len(dataset))
566
+
567
+ import torchvision
568
+ os.makedirs('debug', exist_ok=True)
569
+ for i in range(10): # Check 10 samples
570
+ sample = dataset[0]
571
+ print(f"processing {i}")
572
+
573
+ # Get images and mask
574
+ ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
575
+ target_img = (sample['pixel_values_vid'] + 1.0) / 2
576
+ face_mask = sample['pixel_values_face_mask']
577
+
578
+ # Print dimension information
579
+ print(f"ref_img shape: {ref_img.shape}")
580
+ print(f"target_img shape: {target_img.shape}")
581
+ print(f"face_mask shape: {face_mask.shape}")
582
+
583
+ # Create visualization images
584
+ b, c, h, w = ref_img.shape
585
+
586
+ # Apply mask only to target image
587
+ target_mask = face_mask
588
+
589
+ # Keep reference image unchanged
590
+ ref_with_mask = ref_img.clone()
591
+
592
+ # Create mask overlay for target image
593
+ target_with_mask = target_img.clone()
594
+ target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
595
+
596
+ # Save original images, mask, and overlay results
597
+ # First row: original images
598
+ # Second row: mask
599
+ # Third row: overlay effect
600
+ concatenated_img = torch.cat((
601
+ ref_img, target_img, # Original images
602
+ torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
603
+ ref_with_mask, target_with_mask # Overlay effect
604
+ ), dim=3)
605
+
606
+ torchvision.utils.save_image(
607
+ concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
musetalk/data/sample_method.py ADDED
@@ -0,0 +1,233 @@
1
+ import numpy as np
2
+ import random
3
+
4
+ def summarize_tensor(x):
5
+ return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
6
+
7
+ def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
8
+ num_landmarks = len(landmarks_list)
9
+ mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
10
+ print(np.shape(landmarks_list))
11
+ ## Calculate mouth opening ratios
12
+ for i, landmarks in enumerate(landmarks_list):
13
+ # Assuming landmarks are in the format [x, y] and accessible by index
14
+ mouth_top = landmarks[165] # Adjust index according to your landmarks format
15
+ mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
16
+ mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
17
+ mouth_open_ratios[i] = mouth_open_ratio
18
+
19
+ # Calculate differences matrix
20
+ differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
21
+ differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
22
+ print(differences_matrix.shape)
23
+ # Find top_k similar indices for each landmark set
24
+ if ascending:
25
+ top_indices = np.argsort(differences_matrix[i])[:top_k]
26
+ else:
27
+ top_indices = np.argsort(-differences_matrix[i])[:top_k]
28
+ similar_landmarks_indices = top_indices.tolist()
29
+ similar_landmarks_distances = differences_matrix_with_signs[i].tolist() # note: do not sort here
30
+
31
+ return similar_landmarks_indices, similar_landmarks_distances
32
+ #############################################################################################
33
+ def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
34
+ num_landmarks = len(landmarks_list)
35
+
36
+ mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
37
+ ## Calculate mouth opening ratios
38
+ #print("landmarks shape",np.shape(landmarks_list))
39
+ for i, landmarks in enumerate(landmarks_list):
40
+ # Assuming landmarks are in the format [x, y] and accessible by index
41
+ #print(landmarks[165])
42
+ mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
43
+ mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
44
+ mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
45
+ mouth_open_ratios[i] = mouth_open_ratio
46
+
47
+ # Find top_k similar indices for each landmark set
48
+ if ascending:
49
+ top_indices = np.argsort(mouth_open_ratios)[:top_k]
50
+ else:
51
+ top_indices = np.argsort(-mouth_open_ratios)[:top_k]
52
+ return top_indices
53
+
54
+ def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
55
+ """
56
+ Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
57
+
58
+ Parameters:
59
+ landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
60
+ image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
61
+ start_index (int): The starting index of the facial landmarks.
62
+ end_index (int): The ending index of the facial landmarks.
63
+ top_k (int): The number of most similar landmark sets to return. Default is 50.
64
+ ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
65
+
66
+ Returns:
67
+ similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
68
+ resized_landmarks (list): A list containing the resized facial landmarks.
69
+ """
70
+ num_landmarks = len(landmarks_list)
71
+ resized_landmarks = []
72
+
73
+ # Preprocess landmarks
74
+ for i in range(num_landmarks):
75
+ landmark_array = np.array(landmarks_list[i])
76
+ selected_landmarks = landmark_array[start_index:end_index]
77
+ resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
78
+ resized_landmarks.append(resized_landmark)
79
+
80
+ resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
81
+
82
+ # Calculate similarity
83
+ distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
84
+ overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
85
+
86
+ if ascending:
87
+ sorted_indices = np.argsort(overall_distances)
88
+ similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
89
+ else:
90
+ sorted_indices = np.argsort(-overall_distances)
91
+ similar_landmarks_indices = sorted_indices[0:top_k].tolist()
92
+
93
+ return similar_landmarks_indices
94
+
95
+ def process_bbox_musetalk(face_array, landmark_array):
96
+ x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
97
+ x_min_lm = min([int(x) for x, y in landmark_array])
98
+ y_min_lm = min([int(y) for x, y in landmark_array])
99
+ x_max_lm = max([int(x) for x, y in landmark_array])
100
+ y_max_lm = max([int(y) for x, y in landmark_array])
101
+ x_min = min(x_min_face, x_min_lm)
102
+ y_min = min(y_min_face, y_min_lm)
103
+ x_max = max(x_max_face, x_max_lm)
104
+ y_max = max(y_max_face, y_max_lm)
105
+
106
+ x_min = max(x_min, 0)
107
+ y_min = max(y_min, 0)
108
+
109
+ return [x_min, y_min, x_max, y_max]
110
+
111
+ def shift_landmarks_to_face_coordinates(landmark_list, face_list):
112
+ """
113
+ Translates the data in landmark_list to the coordinates of the cropped larger face.
114
+
115
+ Parameters:
116
+ landmark_list (list): A list containing multiple sets of facial landmarks.
117
+ face_list (list): A list containing multiple facial images.
118
+
119
+ Returns:
120
+ landmark_list_shift (list): The list of translated landmarks.
121
+ bbox_union (list): The list of union bounding boxes.
122
+ face_shapes (list): The list of facial shapes.
123
+ """
124
+ landmark_list_shift = []
125
+ bbox_union = []
126
+ face_shapes = []
127
+
128
+ for i in range(len(face_list)):
129
+ landmark_array = np.array(landmark_list[i]) # convert to a numpy array (creates a copy)
130
+ face_array = face_list[i]
131
+ f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
132
+ x_min, y_min, x_max, y_max = f_landmark_bbox
133
+ landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
134
+ landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
135
+ landmark_list_shift.append(landmark_array)
136
+ bbox_union.append(f_landmark_bbox)
137
+ face_shapes.append((x_max - x_min, y_max - y_min))
138
+
139
+ return landmark_list_shift, bbox_union, face_shapes
140
+
141
+ def resize_landmark(landmark, w, h, new_w, new_h):
142
+ landmark_norm = landmark / [w, h]
143
+ landmark_resized = landmark_norm * [new_w, new_h]
144
+
145
+ return landmark_resized
146
+
147
+ def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
148
+ """
149
+ Calculate the source index (src_idx) based on the given drive index, T, facial landmarks, and sampling method.
150
+
151
+ Parameters:
152
+ - drive_idx (int): The current drive index.
153
+ - T (int): Total number of frames or a specific range limit.
154
+ - sample_method (str): Sampling method, which can be "random" or other methods.
155
+ - landmarks_list (list): List of facial landmarks.
156
+ - image_shapes (list): List of image shapes.
157
+ - top_k_ratio (float): Ratio for selecting top k similar frames.
158
+
159
+ Returns:
160
+ - src_idx (int): The calculated source index.
161
+ """
162
+ if sample_method == "random":
163
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
164
+ elif sample_method == "pose_similarity":
165
+ top_k = int(top_k_ratio*len(landmarks_list))
166
+ try:
167
+ top_k = int(top_k_ratio*len(landmarks_list))
168
+ # facial contour
169
+ landmark_start_idx = 0
170
+ landmark_end_idx = 16
171
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
172
+ src_idx = random.choice(pose_similarity_list)
173
+ while abs(src_idx-drive_idx)<5:
174
+ src_idx = random.choice(pose_similarity_list)
175
+ except Exception as e:
176
+ print(e)
177
+ return None
178
+ elif sample_method=="pose_similarity_and_closed_mouth":
179
+ # facial contour
180
+ landmark_start_idx = 0
181
+ landmark_end_idx = 16
182
+ try:
183
+ top_k = int(top_k_ratio*len(landmarks_list))
184
+ closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
185
+ #print("closed_mouth_list",closed_mouth_list)
186
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
187
+ #print("pose_similarity_list",pose_similarity_list)
188
+ common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
189
+ if len(common_list) == 0:
190
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
191
+ else:
192
+ src_idx = random.choice(common_list)
193
+
194
+ while abs(src_idx-drive_idx) <5:
195
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
196
+
197
+ except Exception as e:
198
+ print(e)
199
+ return None
200
+
201
+ elif sample_method=="pose_similarity_and_mouth_dissimilarity":
202
+ top_k = int(top_k_ratio*len(landmarks_list))
203
+ try:
204
+ top_k = int(top_k_ratio*len(landmarks_list))
205
+
206
+ # facial contour for 68 landmarks format
207
+ landmark_start_idx = 0
208
+ landmark_end_idx = 16
209
+
210
+ pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
211
+
212
+ # Mouth inner contour for the 68-landmark format
213
+ landmark_start_idx = 60
214
+ landmark_end_idx = 67
215
+
216
+ mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
217
+
218
+ common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
219
+ if len(common_list) == 0:
220
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
221
+ else:
222
+ src_idx = random.choice(common_list)
223
+
224
+ while abs(src_idx-drive_idx) <5:
225
+ src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
226
+
227
+ except Exception as e:
228
+ print(e)
229
+ return None
230
+
231
+ else:
232
+ raise ValueError(f"Unknown sample_method: {sample_method}")
233
+ return src_idx
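
A minimal usage sketch of the frame-sampling helper above; the dummy landmarks, image shapes, and the `top_k_ratio` value are illustrative, not taken from the training config:

```python
# Illustrative only: get_src_idx / calculate_landmarks_similarity are the helpers defined above.
import numpy as np

num_frames = 200
landmarks_list = [np.random.rand(68, 2) * 256 for _ in range(num_frames)]  # dummy 68-point landmarks
image_shapes = [(256, 256)] * num_frames                                   # (w, h) per frame

src_idx = get_src_idx(
    drive_idx=100,
    T=16,                               # clip length; only used by the "random" fallback
    sample_method="pose_similarity",    # or "random", "pose_similarity_and_closed_mouth", ...
    landmarks_list=landmarks_list,
    image_shapes=image_shapes,
    top_k_ratio=0.5,
)
# src_idx is a frame with a similar facial contour but at least 5 frames away
# from drive_idx, or None if the similarity computation raised an exception.
```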
musetalk/loss/basic_loss.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, optim
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
9
+ import musetalk.loss.vgg_face as vgg_face
10
+
11
+ class Interpolate(nn.Module):
12
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
13
+ super(Interpolate, self).__init__()
14
+ self.size = size
15
+ self.scale_factor = scale_factor
16
+ self.mode = mode
17
+ self.align_corners = align_corners
18
+
19
+ def forward(self, input):
20
+ return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
21
+
22
+ def set_requires_grad(net, requires_grad=False):
23
+ if net is not None:
24
+ for param in net.parameters():
25
+ param.requires_grad = requires_grad
26
+
27
+ if __name__ == "__main__":
28
+ cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
29
+
30
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
+ pyramid_scale = [1, 0.5, 0.25, 0.125]
32
+ vgg_IN = vgg_face.Vgg19().to(device)
33
+ pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
34
+ vgg_IN.eval()
35
+ downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
36
+
37
+ image = torch.rand(8, 3, 256, 256).to(device)
38
+ image_pred = torch.rand(8, 3, 256, 256).to(device)
39
+ pyramide_real = pyramid(downsampler(image))
40
+ pyramide_generated = pyramid(downsampler(image_pred))
41
+
42
+
43
+ loss_IN = 0
44
+ for scale in cfg.loss_params.pyramid_scale:
45
+ x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
46
+ y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
47
+ for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
48
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
49
+ loss_IN += weight * value
50
+ loss_IN /= sum(cfg.loss_params.vgg_layer_weight)  # Normalize by the VGG layer weights; the pyramid loss is summed over scales
51
+ print(loss_IN)
52
+
53
+ #print(cfg.model_params.discriminator_params)
54
+
55
+ discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
56
+ discriminator_full = DiscriminatorFullModel(discriminator)
57
+ disc_scales = cfg.model_params.discriminator_params.scales
58
+ # Prepare optimizer and loss function
59
+ optimizer_D = optim.AdamW(discriminator.parameters(),
60
+ lr=cfg.discriminator_train_params.lr,
61
+ weight_decay=cfg.discriminator_train_params.weight_decay,
62
+ betas=cfg.discriminator_train_params.betas,
63
+ eps=cfg.discriminator_train_params.eps)
64
+ scheduler_D = CosineAnnealingLR(optimizer_D,
65
+ T_max=cfg.discriminator_train_params.epochs,
66
+ eta_min=1e-6)
67
+
68
+ discriminator.train()
69
+
70
+ set_requires_grad(discriminator, False)
71
+
72
+ loss_G = 0.
73
+ discriminator_maps_generated = discriminator(pyramide_generated)
74
+ discriminator_maps_real = discriminator(pyramide_real)
75
+
76
+ for scale in disc_scales:
77
+ key = 'prediction_map_%s' % scale
78
+ value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
79
+ loss_G += value
80
+
81
+ print(loss_G)
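
Written out, the two quantities printed by this smoke test are, after the 224x224 resize, the weighted pyramid perceptual loss and the least-squares generator loss (a sketch of what the loops compute; the w_i are `vgg_layer_weight`, phi_i the VGG19 feature slices, D_s the scale-s discriminator):

```latex
\mathcal{L}_{\mathrm{vgg}}
  = \frac{1}{\sum_i w_i} \sum_{s \in S} \sum_i w_i\,
    \operatorname{mean}\!\bigl(\lvert \phi_i(\hat I_s) - \phi_i(I_s) \rvert\bigr),
\qquad
\mathcal{L}_{G}
  = \sum_{s \in S} \operatorname{mean}\!\bigl(\bigl(1 - D_s(\hat I_s)\bigr)^2\bigr)
```

where I_s and \hat I_s are the real and generated images at pyramid scale s.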
musetalk/loss/conv.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class nonorm_Conv2d(nn.Module):
22
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self.conv_block = nn.Sequential(
25
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
26
+ )
27
+ self.act = nn.LeakyReLU(0.01, inplace=True)
28
+
29
+ def forward(self, x):
30
+ out = self.conv_block(x)
31
+ return self.act(out)
32
+
33
+ class Conv2dTranspose(nn.Module):
34
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
35
+ super().__init__(*args, **kwargs)
36
+ self.conv_block = nn.Sequential(
37
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
38
+ nn.BatchNorm2d(cout)
39
+ )
40
+ self.act = nn.ReLU()
41
+
42
+ def forward(self, x):
43
+ out = self.conv_block(x)
44
+ return self.act(out)
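
A quick shape check for the residual variant above (illustrative; the residual add requires `cin == cout` and an unchanged spatial size):

```python
import torch
from musetalk.loss.conv import Conv2d

block = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
print(block(torch.rand(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```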
musetalk/loss/discriminator.py ADDED
@@ -0,0 +1,145 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from musetalk.loss.vgg_face import ImagePyramide
5
+
6
+ class DownBlock2d(nn.Module):
7
+ """
8
+ Simple block for processing video (encoder).
9
+ """
10
+
11
+ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
12
+ super(DownBlock2d, self).__init__()
13
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
14
+
15
+ if sn:
16
+ self.conv = nn.utils.spectral_norm(self.conv)
17
+
18
+ if norm:
19
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
20
+ else:
21
+ self.norm = None
22
+ self.pool = pool
23
+
24
+ def forward(self, x):
25
+ out = x
26
+ out = self.conv(out)
27
+ if self.norm:
28
+ out = self.norm(out)
29
+ out = F.leaky_relu(out, 0.2)
30
+ if self.pool:
31
+ out = F.avg_pool2d(out, (2, 2))
32
+ return out
33
+
34
+
35
+ class Discriminator(nn.Module):
36
+ """
37
+ Discriminator similar to Pix2Pix
38
+ """
39
+
40
+ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
41
+ sn=False, **kwargs):
42
+ super(Discriminator, self).__init__()
43
+
44
+ down_blocks = []
45
+ for i in range(num_blocks):
46
+ down_blocks.append(
47
+ DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
48
+ min(max_features, block_expansion * (2 ** (i + 1))),
49
+ norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
50
+
51
+ self.down_blocks = nn.ModuleList(down_blocks)
52
+ self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
53
+ if sn:
54
+ self.conv = nn.utils.spectral_norm(self.conv)
55
+
56
+ def forward(self, x):
57
+ feature_maps = []
58
+ out = x
59
+
60
+ for down_block in self.down_blocks:
61
+ feature_maps.append(down_block(out))
62
+ out = feature_maps[-1]
63
+ prediction_map = self.conv(out)
64
+
65
+ return feature_maps, prediction_map
66
+
67
+
68
+ class MultiScaleDiscriminator(nn.Module):
69
+ """
70
+ Multi-scale (scale) discriminator
71
+ """
72
+
73
+ def __init__(self, scales=(), **kwargs):
74
+ super(MultiScaleDiscriminator, self).__init__()
75
+ self.scales = scales
76
+ discs = {}
77
+ for scale in scales:
78
+ discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
79
+ self.discs = nn.ModuleDict(discs)
80
+
81
+ def forward(self, x):
82
+ out_dict = {}
83
+ for scale, disc in self.discs.items():
84
+ scale = str(scale).replace('-', '.')
85
+ key = 'prediction_' + scale
86
+ #print(key)
87
+ #print(x)
88
+ feature_maps, prediction_map = disc(x[key])
89
+ out_dict['feature_maps_' + scale] = feature_maps
90
+ out_dict['prediction_map_' + scale] = prediction_map
91
+ return out_dict
92
+
93
+
94
+
95
+ class DiscriminatorFullModel(torch.nn.Module):
96
+ """
97
+ Merge all discriminator related updates into single model for better multi-gpu usage
98
+ """
99
+
100
+ def __init__(self, discriminator):
101
+ super(DiscriminatorFullModel, self).__init__()
102
+ self.discriminator = discriminator
103
+ self.scales = self.discriminator.scales
104
+ print("scales",self.scales)
105
+ self.pyramid = ImagePyramide(self.scales, 3)
106
+ if torch.cuda.is_available():
107
+ self.pyramid = self.pyramid.cuda()
108
+
109
+ self.zero_tensor = None
110
+
111
+ def get_zero_tensor(self, input):
112
+ if self.zero_tensor is None:
113
+ self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
114
+ self.zero_tensor.requires_grad_(False)
115
+ return self.zero_tensor.expand_as(input)
116
+
117
+ def forward(self, x, generated, gan_mode='ls'):
118
+ pyramide_real = self.pyramid(x)
119
+ pyramide_generated = self.pyramid(generated.detach())
120
+
121
+ discriminator_maps_generated = self.discriminator(pyramide_generated)
122
+ discriminator_maps_real = self.discriminator(pyramide_real)
123
+
124
+ value_total = 0
125
+ for scale in self.scales:
126
+ key = 'prediction_map_%s' % scale
127
+ if gan_mode == 'hinge':
128
+ value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
129
+ elif gan_mode == 'ls':
130
+ value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
131
+ else:
132
+ raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
133
+
134
+ value_total += value
135
+
136
+ return value_total
137
+
138
+ def main():
139
+ discriminator = MultiScaleDiscriminator(scales=[1],
140
+ block_expansion=32,
141
+ max_features=512,
142
+ num_blocks=4,
143
+ sn=True,
144
+ image_channel=3,
145
+ estimate_jacobian=False)
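
A sketch of the dict-based contract between `ImagePyramide` and `MultiScaleDiscriminator` above; the scales and tensor sizes are illustrative, not the training defaults:

```python
import torch
from musetalk.loss.vgg_face import ImagePyramide
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scales = [1, 0.5]

pyramid = ImagePyramide(scales, 3).to(device)
disc = MultiScaleDiscriminator(scales=scales, block_expansion=32,
                               max_features=512, num_blocks=4, sn=True).to(device)

real = torch.rand(2, 3, 256, 256, device=device)
fake = torch.rand(2, 3, 256, 256, device=device)

pyr_real = pyramid(real)   # keys: 'prediction_1', 'prediction_0.5'
out = disc(pyr_real)       # keys: 'feature_maps_<s>', 'prediction_map_<s>'

d_full = DiscriminatorFullModel(disc)       # builds its own pyramid internally
d_loss = d_full(real, fake, gan_mode='ls')  # LS-GAN loss on real vs. detached fake
print(sorted(out.keys()), d_loss.item())
```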
musetalk/loss/resnet.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch.nn as nn
2
+ import math
3
+
4
+ __all__ = ['ResNet', 'resnet50']
5
+
6
+ def conv3x3(in_planes, out_planes, stride=1):
7
+ """3x3 convolution with padding"""
8
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9
+ padding=1, bias=False)
10
+
11
+
12
+ class BasicBlock(nn.Module):
13
+ expansion = 1
14
+
15
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
16
+ super(BasicBlock, self).__init__()
17
+ self.conv1 = conv3x3(inplanes, planes, stride)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.conv2 = conv3x3(planes, planes)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.downsample = downsample
23
+ self.stride = stride
24
+
25
+ def forward(self, x):
26
+ residual = x
27
+
28
+ out = self.conv1(x)
29
+ out = self.bn1(out)
30
+ out = self.relu(out)
31
+
32
+ out = self.conv2(out)
33
+ out = self.bn2(out)
34
+
35
+ if self.downsample is not None:
36
+ residual = self.downsample(x)
37
+
38
+ out += residual
39
+ out = self.relu(out)
40
+
41
+ return out
42
+
43
+
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
48
+ super(Bottleneck, self).__init__()
49
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
50
+ self.bn1 = nn.BatchNorm2d(planes)
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
52
+ self.bn2 = nn.BatchNorm2d(planes)
53
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54
+ self.bn3 = nn.BatchNorm2d(planes * 4)
55
+ self.relu = nn.ReLU(inplace=True)
56
+ self.downsample = downsample
57
+ self.stride = stride
58
+
59
+ def forward(self, x):
60
+ residual = x
61
+
62
+ out = self.conv1(x)
63
+ out = self.bn1(out)
64
+ out = self.relu(out)
65
+
66
+ out = self.conv2(out)
67
+ out = self.bn2(out)
68
+ out = self.relu(out)
69
+
70
+ out = self.conv3(out)
71
+ out = self.bn3(out)
72
+
73
+ if self.downsample is not None:
74
+ residual = self.downsample(x)
75
+
76
+ out += residual
77
+ out = self.relu(out)
78
+
79
+ return out
80
+
81
+
82
+ class ResNet(nn.Module):
83
+
84
+ def __init__(self, block, layers, num_classes=1000, include_top=True):
85
+ self.inplanes = 64
86
+ super(ResNet, self).__init__()
87
+ self.include_top = include_top
88
+
89
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
90
+ self.bn1 = nn.BatchNorm2d(64)
91
+ self.relu = nn.ReLU(inplace=True)
92
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
93
+
94
+ self.layer1 = self._make_layer(block, 64, layers[0])
95
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
96
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
97
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
98
+ self.avgpool = nn.AvgPool2d(7, stride=1)
99
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
100
+
101
+ for m in self.modules():
102
+ if isinstance(m, nn.Conv2d):
103
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104
+ m.weight.data.normal_(0, math.sqrt(2. / n))
105
+ elif isinstance(m, nn.BatchNorm2d):
106
+ m.weight.data.fill_(1)
107
+ m.bias.data.zero_()
108
+
109
+ def _make_layer(self, block, planes, blocks, stride=1):
110
+ downsample = None
111
+ if stride != 1 or self.inplanes != planes * block.expansion:
112
+ downsample = nn.Sequential(
113
+ nn.Conv2d(self.inplanes, planes * block.expansion,
114
+ kernel_size=1, stride=stride, bias=False),
115
+ nn.BatchNorm2d(planes * block.expansion),
116
+ )
117
+
118
+ layers = []
119
+ layers.append(block(self.inplanes, planes, stride, downsample))
120
+ self.inplanes = planes * block.expansion
121
+ for i in range(1, blocks):
122
+ layers.append(block(self.inplanes, planes))
123
+
124
+ return nn.Sequential(*layers)
125
+
126
+ def forward(self, x):
127
+ x = x * 255.
128
+ x = x.flip(1)
129
+ x = self.conv1(x)
130
+ x = self.bn1(x)
131
+ x = self.relu(x)
132
+ x = self.maxpool(x)
133
+
134
+ x = self.layer1(x)
135
+ x = self.layer2(x)
136
+ x = self.layer3(x)
137
+ x = self.layer4(x)
138
+
139
+ x = self.avgpool(x)
140
+
141
+ if not self.include_top:
142
+ return x
143
+
144
+ x = x.view(x.size(0), -1)
145
+ x = self.fc(x)
146
+ return x
147
+
148
+ def resnet50(**kwargs):
149
+ """Constructs a ResNet-50 model.
150
+ """
151
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
152
+ return model
musetalk/loss/syncnet.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from .conv import Conv2d
6
+
7
+ logloss = nn.BCELoss(reduction="none")
8
+ def cosine_loss(a, v, y):
9
+ d = nn.functional.cosine_similarity(a, v)
10
+ d = d.clamp(0, 1)  # cosine_similarity lies in [-1, 1]; feeding negative values to BCELoss triggers a CUDA device-side assert (RuntimeError)
11
+ loss = logloss(d.unsqueeze(1), y).squeeze()
12
+ loss = loss.mean()
13
+ return loss, d
14
+
15
+ def get_sync_loss(
16
+ audio_embed,
17
+ gt_frames,
18
+ pred_frames,
19
+ syncnet,
20
+ adapted_weight,
21
+ frames_left_index=0,
22
+ frames_right_index=16,
23
+ ):
24
+ # Splice the predicted frames into the ground-truth sequence to reduce GPU memory usage
25
+ assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
26
+ # 3-channel images
27
+ frames_sync_loss = torch.cat(
28
+ [gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
29
+ axis=1
30
+ )
31
+ vision_embed = syncnet.get_image_embed(frames_sync_loss)
32
+ y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
33
+ loss, score = cosine_loss(audio_embed, vision_embed, y)
34
+ return loss, score
35
+
36
+ class SyncNet_color(nn.Module):
37
+ def __init__(self):
38
+ super(SyncNet_color, self).__init__()
39
+
40
+ self.face_encoder = nn.Sequential(
41
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
42
+
43
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
51
+
52
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
53
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
54
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
55
+
56
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
57
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
58
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
59
+
60
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
62
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
63
+
64
+ self.audio_encoder = nn.Sequential(
65
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
66
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
67
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
68
+
69
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
70
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
71
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
72
+
73
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
74
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
75
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
76
+
77
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
78
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
79
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
80
+
81
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
82
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
83
+
84
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
85
+ face_embedding = self.face_encoder(face_sequences)
86
+ audio_embedding = self.audio_encoder(audio_sequences)
87
+
88
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
89
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
90
+
91
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
92
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
93
+
94
+
95
+ return audio_embedding, face_embedding
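
A minimal sketch of the cosine + BCE sync scoring above, with assumed embedding shapes; in training, the real embeddings come from the SyncNet defined in `musetalk/models/syncnet.py`:

```python
import torch
import torch.nn.functional as F
from musetalk.loss.syncnet import cosine_loss

B, D = 4, 512
audio_embed = F.normalize(torch.rand(B, D), dim=1)   # stand-in unit-norm embeddings
vision_embed = F.normalize(torch.rand(B, D), dim=1)
y = torch.ones(B, 1)                                 # label 1: the pair should be in sync

loss, score = cosine_loss(audio_embed, vision_embed, y)
print(loss.item(), score.shape)  # scalar BCE loss, per-sample similarity clamped to [0, 1]
```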
musetalk/loss/vgg_face.py ADDED
@@ -0,0 +1,237 @@
1
+ '''
2
+ This part of the code contains a pretrained VGG-Face model.
3
+ ref link: https://github.com/prlz77/vgg-face.pytorch
4
+ '''
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo
8
+ import pickle
9
+ from musetalk.loss import resnet as ResNet
10
+
11
+
12
+ MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
13
+ VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
14
+
15
+ # It was 93.5940, 104.7624, 129.1863 before dividing by 255
16
+ MEAN_RGB = [
17
+ 0.367035294117647,
18
+ 0.41083294117647057,
19
+ 0.5066129411764705
20
+ ]
21
+ def load_state_dict(model, fname):
22
+ """
23
+ Set parameters converted from Caffe models authors of VGGFace2 provide.
24
+ See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
25
+
26
+ Arguments:
27
+ model: model
28
+ fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
29
+ """
30
+ with open(fname, 'rb') as f:
31
+ weights = pickle.load(f, encoding='latin1')
32
+
33
+ own_state = model.state_dict()
34
+ for name, param in weights.items():
35
+ if name in own_state:
36
+ try:
37
+ own_state[name].copy_(torch.from_numpy(param))
38
+ except Exception:
39
+ raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
40
+ 'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
41
+ else:
42
+ raise KeyError('unexpected key "{}" in state_dict'.format(name))
43
+
44
+
45
+ def vggface2(pretrained=True):
46
+ vggface = ResNet.resnet50(num_classes=8631, include_top=True)
47
+ load_state_dict(vggface, VGG_FACE_PATH)
48
+ return vggface
49
+
50
+ def vggface(pretrained=False, **kwargs):
51
+ """VGGFace model.
52
+
53
+ Args:
54
+ pretrained (bool): If True, returns pre-trained model
55
+ """
56
+ model = VggFace(**kwargs)
57
+ if pretrained:
58
+ state = torch.utils.model_zoo.load_url(MODEL_URL)
59
+ model.load_state_dict(state)
60
+ return model
61
+
62
+
63
+ class VggFace(torch.nn.Module):
64
+ def __init__(self, classes=2622):
65
+ """VGGFace model.
66
+
67
+ Face recognition network. It takes as input a Bx3x224x224
68
+ batch of face images and gives as output a BxC score vector
69
+ (C is the number of identities).
70
+ Input images need to be scaled in the 0-1 range and then
71
+ normalized with respect to the mean RGB used during training.
72
+
73
+ Args:
74
+ classes (int): number of identities recognized by the
75
+ network
76
+
77
+ """
78
+ super().__init__()
79
+ self.conv1 = _ConvBlock(3, 64, 64)
80
+ self.conv2 = _ConvBlock(64, 128, 128)
81
+ self.conv3 = _ConvBlock(128, 256, 256, 256)
82
+ self.conv4 = _ConvBlock(256, 512, 512, 512)
83
+ self.conv5 = _ConvBlock(512, 512, 512, 512)
84
+ self.dropout = torch.nn.Dropout(0.5)
85
+ self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
86
+ self.fc2 = torch.nn.Linear(4096, 4096)
87
+ self.fc3 = torch.nn.Linear(4096, classes)
88
+
89
+ def forward(self, x):
90
+ x = self.conv1(x)
91
+ x = self.conv2(x)
92
+ x = self.conv3(x)
93
+ x = self.conv4(x)
94
+ x = self.conv5(x)
95
+ x = x.view(x.size(0), -1)
96
+ x = self.dropout(F.relu(self.fc1(x)))
97
+ x = self.dropout(F.relu(self.fc2(x)))
98
+ x = self.fc3(x)
99
+ return x
100
+
101
+
102
+ class _ConvBlock(torch.nn.Module):
103
+ """A Convolutional block."""
104
+
105
+ def __init__(self, *units):
106
+ """Create a block with len(units) - 1 convolutions.
107
+
108
+ convolution number i transforms the number of channels from
109
+ units[i - 1] to units[i] channels.
110
+
111
+ """
112
+ super().__init__()
113
+ self.convs = torch.nn.ModuleList([
114
+ torch.nn.Conv2d(in_, out, 3, 1, 1)
115
+ for in_, out in zip(units[:-1], units[1:])
116
+ ])
117
+
118
+ def forward(self, x):
119
+ # Each convolution is followed by a ReLU, then the block is
120
+ # concluded by a max pooling.
121
+ for c in self.convs:
122
+ x = F.relu(c(x))
123
+ return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
124
+
125
+
126
+
127
+ import numpy as np
128
+ from torchvision import models
129
+ class Vgg19(torch.nn.Module):
130
+ """
131
+ Vgg19 network for perceptual loss.
132
+ """
133
+ def __init__(self, requires_grad=False):
134
+ super(Vgg19, self).__init__()
135
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
136
+ self.slice1 = torch.nn.Sequential()
137
+ self.slice2 = torch.nn.Sequential()
138
+ self.slice3 = torch.nn.Sequential()
139
+ self.slice4 = torch.nn.Sequential()
140
+ self.slice5 = torch.nn.Sequential()
141
+ for x in range(2):
142
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
143
+ for x in range(2, 7):
144
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
145
+ for x in range(7, 12):
146
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
147
+ for x in range(12, 21):
148
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
149
+ for x in range(21, 30):
150
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
151
+
152
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
153
+ requires_grad=False)
154
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
155
+ requires_grad=False)
156
+
157
+ if not requires_grad:
158
+ for param in self.parameters():
159
+ param.requires_grad = False
160
+
161
+ def forward(self, X):
162
+ X = (X - self.mean) / self.std
163
+ h_relu1 = self.slice1(X)
164
+ h_relu2 = self.slice2(h_relu1)
165
+ h_relu3 = self.slice3(h_relu2)
166
+ h_relu4 = self.slice4(h_relu3)
167
+ h_relu5 = self.slice5(h_relu4)
168
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
169
+ return out
170
+
171
+
172
+ from torch import nn
173
+ class AntiAliasInterpolation2d(nn.Module):
174
+ """
175
+ Band-limited downsampling, for better preservation of the input signal.
176
+ """
177
+ def __init__(self, channels, scale):
178
+ super(AntiAliasInterpolation2d, self).__init__()
179
+ sigma = (1 / scale - 1) / 2
180
+ kernel_size = 2 * round(sigma * 4) + 1
181
+ self.ka = kernel_size // 2
182
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
183
+
184
+ kernel_size = [kernel_size, kernel_size]
185
+ sigma = [sigma, sigma]
186
+ # The gaussian kernel is the product of the
187
+ # gaussian function of each dimension.
188
+ kernel = 1
189
+ meshgrids = torch.meshgrid(
190
+ [
191
+ torch.arange(size, dtype=torch.float32)
192
+ for size in kernel_size
193
+ ]
194
+ )
195
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
196
+ mean = (size - 1) / 2
197
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
198
+
199
+ # Make sure sum of values in gaussian kernel equals 1.
200
+ kernel = kernel / torch.sum(kernel)
201
+ # Reshape to depthwise convolutional weight
202
+ kernel = kernel.view(1, 1, *kernel.size())
203
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
204
+
205
+ self.register_buffer('weight', kernel)
206
+ self.groups = channels
207
+ self.scale = scale
208
+ inv_scale = 1 / scale
209
+ self.int_inv_scale = int(inv_scale)
210
+
211
+ def forward(self, input):
212
+ if self.scale == 1.0:
213
+ return input
214
+
215
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
216
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
217
+ out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
218
+
219
+ return out
220
+
221
+
222
+ class ImagePyramide(torch.nn.Module):
223
+ """
224
+ Create image pyramide for computing pyramide perceptual loss.
225
+ """
226
+ def __init__(self, scales, num_channels):
227
+ super(ImagePyramide, self).__init__()
228
+ downs = {}
229
+ for scale in scales:
230
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
231
+ self.downs = nn.ModuleDict(downs)
232
+
233
+ def forward(self, x):
234
+ out_dict = {}
235
+ for scale, down_module in self.downs.items():
236
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
237
+ return out_dict
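
The band-limited downsampler above blurs with a Gaussian whose width is tied to the scale factor before striding; a quick shape check with illustrative values:

```python
import torch
from musetalk.loss.vgg_face import AntiAliasInterpolation2d

down = AntiAliasInterpolation2d(channels=3, scale=0.25)  # sigma = 1.5, 13x13 kernel
y = down(torch.rand(1, 3, 256, 256))
print(y.shape)  # torch.Size([1, 3, 64, 64]) -- blurred, then subsampled by 4
```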
musetalk/models/syncnet.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
3
+ """
4
+
5
+ import torch
6
+ from torch import nn
7
+ from einops import rearrange
8
+ from torch.nn import functional as F
9
+
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from diffusers.models.attention import Attention as CrossAttention, FeedForward
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from einops import rearrange
16
+
17
+
18
+ class SyncNet(nn.Module):
19
+ def __init__(self, config):
20
+ super().__init__()
21
+ self.audio_encoder = DownEncoder2D(
22
+ in_channels=config["audio_encoder"]["in_channels"],
23
+ block_out_channels=config["audio_encoder"]["block_out_channels"],
24
+ downsample_factors=config["audio_encoder"]["downsample_factors"],
25
+ dropout=config["audio_encoder"]["dropout"],
26
+ attn_blocks=config["audio_encoder"]["attn_blocks"],
27
+ )
28
+
29
+ self.visual_encoder = DownEncoder2D(
30
+ in_channels=config["visual_encoder"]["in_channels"],
31
+ block_out_channels=config["visual_encoder"]["block_out_channels"],
32
+ downsample_factors=config["visual_encoder"]["downsample_factors"],
33
+ dropout=config["visual_encoder"]["dropout"],
34
+ attn_blocks=config["visual_encoder"]["attn_blocks"],
35
+ )
36
+
37
+ self.eval()
38
+
39
+ def forward(self, image_sequences, audio_sequences):
40
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
41
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
42
+
43
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
44
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
45
+
46
+ # Make them unit vectors
47
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
48
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
49
+
50
+ return vision_embeds, audio_embeds
51
+
52
+ def get_image_embed(self, image_sequences):
53
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
54
+
55
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
56
+
57
+ # Make them unit vectors
58
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
59
+
60
+ return vision_embeds
61
+
62
+ def get_audio_embed(self, audio_sequences):
63
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
64
+
65
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
66
+
67
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
68
+
69
+ return audio_embeds
70
+
71
+ class ResnetBlock2D(nn.Module):
72
+ def __init__(
73
+ self,
74
+ in_channels: int,
75
+ out_channels: int,
76
+ dropout: float = 0.0,
77
+ norm_num_groups: int = 32,
78
+ eps: float = 1e-6,
79
+ act_fn: str = "silu",
80
+ downsample_factor=2,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
85
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
86
+
87
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
88
+ self.dropout = nn.Dropout(dropout)
89
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
90
+
91
+ if act_fn == "relu":
92
+ self.act_fn = nn.ReLU()
93
+ elif act_fn == "silu":
94
+ self.act_fn = nn.SiLU()
95
+
96
+ if in_channels != out_channels:
97
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
98
+ else:
99
+ self.conv_shortcut = None
100
+
101
+ if isinstance(downsample_factor, list):
102
+ downsample_factor = tuple(downsample_factor)
103
+
104
+ if downsample_factor == 1:
105
+ self.downsample_conv = None
106
+ else:
107
+ self.downsample_conv = nn.Conv2d(
108
+ out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
109
+ )
110
+ self.pad = (0, 1, 0, 1)
111
+ if isinstance(downsample_factor, tuple):
112
+ if downsample_factor[0] == 1:
113
+ self.pad = (0, 1, 1, 1) # The padding order is from back to front
114
+ elif downsample_factor[1] == 1:
115
+ self.pad = (1, 1, 0, 1)
116
+
117
+ def forward(self, input_tensor):
118
+ hidden_states = input_tensor
119
+
120
+ hidden_states = self.norm1(hidden_states)
121
+ hidden_states = self.act_fn(hidden_states)
122
+
123
+ hidden_states = self.conv1(hidden_states)
124
+ hidden_states = self.norm2(hidden_states)
125
+ hidden_states = self.act_fn(hidden_states)
126
+
127
+ hidden_states = self.dropout(hidden_states)
128
+ hidden_states = self.conv2(hidden_states)
129
+
130
+ if self.conv_shortcut is not None:
131
+ input_tensor = self.conv_shortcut(input_tensor)
132
+
133
+ hidden_states += input_tensor
134
+
135
+ if self.downsample_conv is not None:
136
+ hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
137
+ hidden_states = self.downsample_conv(hidden_states)
138
+
139
+ return hidden_states
140
+
141
+
142
+ class AttentionBlock2D(nn.Module):
143
+ def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
144
+ super().__init__()
145
+ if not is_xformers_available():
146
+ raise ModuleNotFoundError(
147
+ "You have to install xformers to enable memory efficient attetion", name="xformers"
148
+ )
149
+ # inner_dim = dim_head * heads
150
+ self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
151
+ self.norm2 = nn.LayerNorm(query_dim)
152
+ self.norm3 = nn.LayerNorm(query_dim)
153
+
154
+ self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
155
+
156
+ self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
157
+ self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
158
+
159
+ self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
160
+ self.attn._use_memory_efficient_attention_xformers = True
161
+
162
+ def forward(self, hidden_states):
163
+ assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
164
+
165
+ batch, channel, height, width = hidden_states.shape
166
+ residual = hidden_states
167
+
168
+ hidden_states = self.norm1(hidden_states)
169
+ hidden_states = self.conv_in(hidden_states)
170
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
171
+
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+ hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
174
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
175
+
176
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
177
+ hidden_states = self.conv_out(hidden_states)
178
+
179
+ hidden_states = hidden_states + residual
180
+ return hidden_states
181
+
182
+
183
+ class DownEncoder2D(nn.Module):
184
+ def __init__(
185
+ self,
186
+ in_channels=4 * 16,
187
+ block_out_channels=[64, 128, 256, 256],
188
+ downsample_factors=[2, 2, 2, 2],
189
+ layers_per_block=2,
190
+ norm_num_groups=32,
191
+ attn_blocks=[1, 1, 1, 1],
192
+ dropout: float = 0.0,
193
+ act_fn="silu",
194
+ ):
195
+ super().__init__()
196
+ self.layers_per_block = layers_per_block
197
+
198
+ # in
199
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
200
+
201
+ # down
202
+ self.down_blocks = nn.ModuleList([])
203
+
204
+ output_channels = block_out_channels[0]
205
+ for i, block_out_channel in enumerate(block_out_channels):
206
+ input_channels = output_channels
207
+ output_channels = block_out_channel
208
+ # is_final_block = i == len(block_out_channels) - 1
209
+
210
+ down_block = ResnetBlock2D(
211
+ in_channels=input_channels,
212
+ out_channels=output_channels,
213
+ downsample_factor=downsample_factors[i],
214
+ norm_num_groups=norm_num_groups,
215
+ dropout=dropout,
216
+ act_fn=act_fn,
217
+ )
218
+
219
+ self.down_blocks.append(down_block)
220
+
221
+ if attn_blocks[i] == 1:
222
+ attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
223
+ self.down_blocks.append(attention_block)
224
+
225
+ # out
226
+ self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
227
+ self.act_fn_out = nn.ReLU()
228
+
229
+ def forward(self, hidden_states):
230
+ hidden_states = self.conv_in(hidden_states)
231
+
232
+ # down
233
+ for down_block in self.down_blocks:
234
+ hidden_states = down_block(hidden_states)
235
+
236
+ # post-process
237
+ hidden_states = self.norm_out(hidden_states)
238
+ hidden_states = self.act_fn_out(hidden_states)
239
+
240
+ return hidden_states
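
For reference, a hypothetical config dict in the shape this SyncNet constructor expects; the actual channel counts and downsample factors live in the YAML pointed to by `cfg.syncnet_config_path` and may differ:

```python
from musetalk.models.syncnet import SyncNet

config = {
    "audio_encoder": {
        "in_channels": 1,                          # assumption: single-channel audio features
        "block_out_channels": [64, 128, 256, 256],
        "downsample_factors": [2, 2, 2, 2],
        "dropout": 0.0,
        "attn_blocks": [0, 0, 0, 0],               # 1 inserts an AttentionBlock2D (requires xformers)
    },
    "visual_encoder": {
        "in_channels": 3 * 16,                     # assumption: 16 stacked RGB frames
        "block_out_channels": [64, 128, 256, 256],
        "downsample_factors": [2, 2, 2, 2],
        "dropout": 0.0,
        "attn_blocks": [0, 0, 0, 0],
    },
}
net = SyncNet(config)
# vision_embeds, audio_embeds = net(image_sequences, audio_sequences)
```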
musetalk/utils/training_utils.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+ import json
3
+ import logging
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.optim.lr_scheduler import CosineAnnealingLR
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel
9
+ from transformers import WhisperModel
10
+ from diffusers.optimization import get_scheduler
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange
13
+
14
+ from musetalk.models.syncnet import SyncNet
15
+ from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
16
+ from musetalk.loss.basic_loss import Interpolate
17
+ import musetalk.loss.vgg_face as vgg_face
18
+ from musetalk.data.dataset import PortraitDataset
19
+ from musetalk.utils.utils import (
20
+ get_image_pred,
21
+ process_audio_features,
22
+ process_and_save_images
23
+ )
24
+
25
+ class Net(nn.Module):
26
+ def __init__(
27
+ self,
28
+ unet: UNet2DConditionModel,
29
+ ):
30
+ super().__init__()
31
+ self.unet = unet
32
+
33
+ def forward(
34
+ self,
35
+ input_latents,
36
+ timesteps,
37
+ audio_prompts,
38
+ ):
39
+ model_pred = self.unet(
40
+ input_latents,
41
+ timesteps,
42
+ encoder_hidden_states=audio_prompts
43
+ ).sample
44
+ return model_pred
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+ def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
49
+ """Initialize models and optimizers"""
50
+ model_dict = {
51
+ 'vae': None,
52
+ 'unet': None,
53
+ 'net': None,
54
+ 'wav2vec': None,
55
+ 'optimizer': None,
56
+ 'lr_scheduler': None,
57
+ 'scheduler_max_steps': None,
58
+ 'trainable_params': None
59
+ }
60
+
61
+ model_dict['vae'] = AutoencoderKL.from_pretrained(
62
+ cfg.pretrained_model_name_or_path,
63
+ subfolder=cfg.vae_type,
64
+ )
65
+
66
+ unet_config_file = os.path.join(
67
+ cfg.pretrained_model_name_or_path,
68
+ cfg.unet_sub_folder + "/musetalk.json"
69
+ )
70
+
71
+ with open(unet_config_file, 'r') as f:
72
+ unet_config = json.load(f)
73
+ model_dict['unet'] = UNet2DConditionModel(**unet_config)
74
+
75
+ if not cfg.random_init_unet:
76
+ pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
77
+ print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
78
+ checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
79
+ model_dict['unet'].load_state_dict(checkpoint)
80
+
81
+ unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
82
+ logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
83
+
84
+ model_dict['vae'].requires_grad_(False)
85
+ model_dict['unet'].requires_grad_(True)
86
+
87
+ model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
88
+
89
+ model_dict['net'] = Net(model_dict['unet'])
90
+
91
+ model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
92
+ device="cuda", dtype=weight_dtype).eval()
93
+ model_dict['wav2vec'].requires_grad_(False)
94
+
95
+ if cfg.solver.gradient_checkpointing:
96
+ model_dict['unet'].enable_gradient_checkpointing()
97
+
98
+ if cfg.solver.scale_lr:
99
+ learning_rate = (
100
+ cfg.solver.learning_rate
101
+ * cfg.solver.gradient_accumulation_steps
102
+ * cfg.data.train_bs
103
+ * accelerator.num_processes
104
+ )
105
+ else:
106
+ learning_rate = cfg.solver.learning_rate
107
+
108
+ if cfg.solver.use_8bit_adam:
109
+ try:
110
+ import bitsandbytes as bnb
111
+ except ImportError:
112
+ raise ImportError(
113
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
114
+ )
115
+ optimizer_cls = bnb.optim.AdamW8bit
116
+ else:
117
+ optimizer_cls = torch.optim.AdamW
118
+
119
+ model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
120
+ if accelerator.is_main_process:
121
+ print('trainable params')
122
+ for n, p in model_dict['net'].named_parameters():
123
+ if p.requires_grad:
124
+ print(n)
125
+
126
+ model_dict['optimizer'] = optimizer_cls(
127
+ model_dict['trainable_params'],
128
+ lr=learning_rate,
129
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
130
+ weight_decay=cfg.solver.adam_weight_decay,
131
+ eps=cfg.solver.adam_epsilon,
132
+ )
133
+
134
+ model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
135
+ model_dict['lr_scheduler'] = get_scheduler(
136
+ cfg.solver.lr_scheduler,
137
+ optimizer=model_dict['optimizer'],
138
+ num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
139
+ num_training_steps=model_dict['scheduler_max_steps'],
140
+ )
141
+
142
+ return model_dict
143
+
144
+ def initialize_dataloaders(cfg):
145
+ """Initialize training and validation dataloaders"""
146
+ dataloader_dict = {
147
+ 'train_dataset': None,
148
+ 'val_dataset': None,
149
+ 'train_dataloader': None,
150
+ 'val_dataloader': None
151
+ }
152
+
153
+ dataloader_dict['train_dataset'] = PortraitDataset(cfg={
154
+ 'image_size': cfg.data.image_size,
155
+ 'T': cfg.data.n_sample_frames,
156
+ "sample_method": cfg.data.sample_method,
157
+ 'top_k_ratio': cfg.data.top_k_ratio,
158
+ "contorl_face_min_size": cfg.data.contorl_face_min_size,
159
+ "dataset_key": cfg.data.dataset_key,
160
+ "padding_pixel_mouth": cfg.padding_pixel_mouth,
161
+ "whisper_path": cfg.whisper_path,
162
+ "min_face_size": cfg.data.min_face_size,
163
+ "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
164
+ "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
165
+ "crop_type": cfg.crop_type,
166
+ "random_margin_method": cfg.random_margin_method,
167
+ })
168
+
169
+ dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
170
+ dataloader_dict['train_dataset'],
171
+ batch_size=cfg.data.train_bs,
172
+ shuffle=True,
173
+ num_workers=cfg.data.num_workers,
174
+ )
175
+
176
+ dataloader_dict['val_dataset'] = PortraitDataset(cfg={
177
+ 'image_size': cfg.data.image_size,
178
+ 'T': cfg.data.n_sample_frames,
179
+ "sample_method": cfg.data.sample_method,
180
+ 'top_k_ratio': cfg.data.top_k_ratio,
181
+ "contorl_face_min_size": cfg.data.contorl_face_min_size,
182
+ "dataset_key": cfg.data.dataset_key,
183
+ "padding_pixel_mouth": cfg.padding_pixel_mouth,
184
+ "whisper_path": cfg.whisper_path,
185
+ "min_face_size": cfg.data.min_face_size,
186
+ "cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
187
+ "cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
188
+ "crop_type": cfg.crop_type,
189
+ "random_margin_method": cfg.random_margin_method,
190
+ })
191
+
192
+ dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
193
+ dataloader_dict['val_dataset'],
194
+ batch_size=cfg.data.train_bs,
195
+ shuffle=True,
196
+ num_workers=1,
197
+ )
198
+
199
+ return dataloader_dict
200
+
201
+ def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
202
+ """Initialize loss functions and discriminators"""
203
+ loss_dict = {
204
+ 'L1_loss': nn.L1Loss(reduction='mean'),
205
+ 'discriminator': None,
206
+ 'mouth_discriminator': None,
207
+ 'optimizer_D': None,
208
+ 'mouth_optimizer_D': None,
209
+ 'scheduler_D': None,
210
+ 'mouth_scheduler_D': None,
211
+ 'disc_scales': None,
212
+ 'discriminator_full': None,
213
+ 'mouth_discriminator_full': None
214
+ }
215
+
216
+ if cfg.loss_params.gan_loss > 0:
217
+ loss_dict['discriminator'] = MultiScaleDiscriminator(
218
+ **cfg.model_params.discriminator_params).to(accelerator.device)
219
+ loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
220
+ loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
221
+ loss_dict['optimizer_D'] = optim.AdamW(
222
+ loss_dict['discriminator'].parameters(),
223
+ lr=cfg.discriminator_train_params.lr,
224
+ weight_decay=cfg.discriminator_train_params.weight_decay,
225
+ betas=cfg.discriminator_train_params.betas,
226
+ eps=cfg.discriminator_train_params.eps)
227
+ loss_dict['scheduler_D'] = CosineAnnealingLR(
228
+ loss_dict['optimizer_D'],
229
+ T_max=scheduler_max_steps,
230
+ eta_min=1e-6
231
+ )
232
+
233
+ if cfg.loss_params.mouth_gan_loss > 0:
234
+ loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
235
+ **cfg.model_params.discriminator_params).to(accelerator.device)
236
+ loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
237
+ loss_dict['mouth_optimizer_D'] = optim.AdamW(
238
+ loss_dict['mouth_discriminator'].parameters(),
239
+ lr=cfg.discriminator_train_params.lr,
240
+ weight_decay=cfg.discriminator_train_params.weight_decay,
241
+ betas=cfg.discriminator_train_params.betas,
242
+ eps=cfg.discriminator_train_params.eps)
243
+ loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
244
+ loss_dict['mouth_optimizer_D'],
245
+ T_max=scheduler_max_steps,
246
+ eta_min=1e-6
247
+ )
248
+
249
+ return loss_dict
250
+
251
+ def initialize_syncnet(cfg, accelerator, weight_dtype):
252
+ """Initialize SyncNet model"""
253
+ if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
254
+ if cfg.data.n_sample_frames != 16:
255
+ raise ValueError(
256
+ f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
257
+ )
258
+ syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
259
+ syncnet = SyncNet(OmegaConf.to_container(
260
+ syncnet_config.model)).to(accelerator.device)
261
+ print(
262
+ f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
263
+ checkpoint = torch.load(
264
+ syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
265
+ syncnet.load_state_dict(checkpoint["state_dict"])
266
+ syncnet.to(dtype=weight_dtype)
267
+ syncnet.requires_grad_(False)
268
+ syncnet.eval()
269
+ return syncnet
270
+ return None
271
+
272
+ def initialize_vgg(cfg, accelerator):
273
+ """Initialize VGG model"""
274
+ if cfg.loss_params.vgg_loss > 0:
275
+ vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
276
+ pyramid = vgg_face.ImagePyramide(
277
+ cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
278
+ vgg_IN.eval()
279
+ downsampler = Interpolate(
280
+ size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
281
+ return vgg_IN, pyramid, downsampler
282
+ return None, None, None
283
+
284
+ def validation(
285
+ cfg,
286
+ val_dataloader,
287
+ net,
288
+ vae,
289
+ wav2vec,
290
+ accelerator,
291
+ save_dir,
292
+ global_step,
293
+ weight_dtype,
294
+ syncnet_score=1,
295
+ ):
296
+ """Validation function for model evaluation"""
297
+ net.eval() # Set the model to evaluation mode
298
+ for batch in val_dataloader:
299
+ # The same ref_latents
300
+ ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
301
+ accelerator.device, non_blocking=True
302
+ )
303
+ pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
304
+ accelerator.device, non_blocking=True
305
+ )
306
+ bsz, num_frames, c, h, w = ref_pixel_values.shape
307
+
308
+ audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
309
+ # audio feature for unet
310
+ audio_prompts = rearrange(
311
+ audio_prompts,
312
+ 'b f c h w-> (b f) c h w'
313
+ )
314
+ audio_prompts = rearrange(
315
+ audio_prompts,
316
+ '(b f) c h w -> (b f) (c h) w',
317
+ b=bsz
318
+ )
319
+ # different masked_latents
320
+ image_pred_train = get_image_pred(
321
+ pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
322
+ image_pred_infer = get_image_pred(
323
+ ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
324
+
325
+ process_and_save_images(
326
+ batch,
327
+ image_pred_train,
328
+ image_pred_infer,
329
+ save_dir,
330
+ global_step,
331
+ accelerator,
332
+ cfg.num_images_to_keep,
333
+ syncnet_score
334
+ )
335
+ # only infer 1 image in validation
336
+ break
337
+ net.train() # Set the model back to training mode
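
A hedged sketch of how the helpers in this file are meant to be wired together by a training script; the config path and the `weight_dtype` field are assumptions, and the real entry point ships with the training release:

```python
import torch
from accelerate import Accelerator
from omegaconf import OmegaConf

from musetalk.utils.training_utils import (
    initialize_models_and_optimizers,
    initialize_dataloaders,
    initialize_loss_functions,
    initialize_syncnet,
    initialize_vgg,
)

cfg = OmegaConf.load("configs/training/stage1.yaml")  # hypothetical path
accelerator = Accelerator(
    gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps)
weight_dtype = torch.float16 if cfg.get("weight_dtype", "fp16") == "fp16" else torch.float32

models = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
data = initialize_dataloaders(cfg)
losses = initialize_loss_functions(cfg, accelerator, models["scheduler_max_steps"])
syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)     # None if the sync loss is disabled
vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)  # (None, None, None) if vgg_loss == 0
```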
musetalk/utils/utils.py CHANGED
@@ -2,6 +2,11 @@ import os
2
  import cv2
3
  import numpy as np
4
  import torch
5
 
6
  ffmpeg_path = os.getenv('FFMPEG_PATH')
7
  if ffmpeg_path is None:
@@ -11,7 +16,6 @@ elif ffmpeg_path not in os.getenv('PATH'):
11
  os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
12
 
13
 
14
- from musetalk.whisper.audio2feature import Audio2Feature
15
  from musetalk.models.vae import VAE
16
  from musetalk.models.unet import UNet,PositionalEncoding
17
 
@@ -76,3 +80,248 @@ def datagen(
76
  latent_batch = torch.cat(latent_batch, dim=0)
77
 
78
  yield whisper_batch.to(device), latent_batch.to(device)
2
  import cv2
3
  import numpy as np
4
  import torch
5
+ from typing import Union, List
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ import shutil
9
+ import os.path as osp
10
 
11
  ffmpeg_path = os.getenv('FFMPEG_PATH')
12
  if ffmpeg_path is None:
 
16
  os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
17
 
18
 
 
19
  from musetalk.models.vae import VAE
20
  from musetalk.models.unet import UNet,PositionalEncoding
21
 
 
80
  latent_batch = torch.cat(latent_batch, dim=0)
81
 
82
  yield whisper_batch.to(device), latent_batch.to(device)
83
+
84
+ def cast_training_params(
85
+ model: Union[torch.nn.Module, List[torch.nn.Module]],
86
+ dtype=torch.float32,
87
+ ):
88
+ if not isinstance(model, list):
89
+ model = [model]
90
+ for m in model:
91
+ for param in m.parameters():
92
+ # only upcast trainable parameters into fp32
93
+ if param.requires_grad:
94
+ param.data = param.to(dtype)
95
+
96
+ def rand_log_normal(
97
+ shape,
98
+ loc=0.,
99
+ scale=1.,
100
+ device='cpu',
101
+ dtype=torch.float32,
102
+ generator=None
103
+ ):
104
+ """Draws samples from an lognormal distribution."""
105
+ rnd_normal = torch.randn(
106
+ shape, device=device, dtype=dtype, generator=generator) # N(0, I)
107
+ sigma = (rnd_normal * scale + loc).exp()
108
+ return sigma
109
+
110
+ def get_mouth_region(frames, image_pred, pixel_values_face_mask):
111
+ # Initialize lists to store the results for each image in the batch
112
+ mouth_real_list = []
113
+ mouth_generated_list = []
114
+
115
+ # Process each image in the batch
116
+ for b in range(frames.shape[0]):
117
+ # Find the non-zero area in the face mask
118
+ non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
119
+ # If there are no non-zero indices, skip this image
120
+ if non_zero_indices.numel() == 0:
121
+ continue
122
+
123
+ min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
124
+ non_zero_indices[:, 1])
125
+ min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
126
+ non_zero_indices[:, 2])
127
+
128
+ # Crop the frames and image_pred according to the non-zero area
129
+ frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
130
+ image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
131
+ # Resize the cropped images to 256*256
132
+ frames_resized = F.interpolate(frames_cropped.unsqueeze(
133
+ 0), size=(256, 256), mode='bilinear', align_corners=False)
134
+ image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
135
+ 0), size=(256, 256), mode='bilinear', align_corners=False)
136
+
137
+ # Append the resized images to the result lists
138
+ mouth_real_list.append(frames_resized)
139
+ mouth_generated_list.append(image_pred_resized)
140
+
141
+ # Convert the lists to tensors if they are not empty
142
+ mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
143
+ mouth_generated = torch.cat(
144
+ mouth_generated_list, dim=0) if mouth_generated_list else None
145
+
146
+ return mouth_real, mouth_generated
147
+
148
+ def get_image_pred(pixel_values,
149
+ ref_pixel_values,
150
+ audio_prompts,
151
+ vae,
152
+ net,
153
+ weight_dtype):
154
+ with torch.no_grad():
155
+ bsz, num_frames, c, h, w = pixel_values.shape
156
+
157
+ masked_pixel_values = pixel_values.clone()
158
+ masked_pixel_values[:, :, :, h//2:, :] = -1
159
+
160
+ masked_frames = rearrange(
161
+ masked_pixel_values, 'b f c h w -> (b f) c h w')
162
+ masked_latents = vae.encode(masked_frames).latent_dist.mode()
163
+ masked_latents = masked_latents * vae.config.scaling_factor
164
+ masked_latents = masked_latents.float()
165
+
166
+ ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
167
+ ref_latents = vae.encode(ref_frames).latent_dist.mode()
168
+ ref_latents = ref_latents * vae.config.scaling_factor
169
+ ref_latents = ref_latents.float()
170
+
171
+ input_latents = torch.cat([masked_latents, ref_latents], dim=1)
172
+ input_latents = input_latents.to(weight_dtype)
173
+ timesteps = torch.tensor([0], device=input_latents.device)
174
+ latents_pred = net(
175
+ input_latents,
176
+ timesteps,
177
+ audio_prompts,
178
+ )
179
+ latents_pred = (1 / vae.config.scaling_factor) * latents_pred
180
+ image_pred = vae.decode(latents_pred).sample
181
+ image_pred = image_pred.float()
182
+
183
+ return image_pred
184
+
185
+ def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
186
+ with torch.no_grad():
187
+ audio_feature_length_per_frame = 2 * \
188
+ (cfg.data.audio_padding_length_left +
189
+ cfg.data.audio_padding_length_right + 1)
190
+ audio_feats = batch['audio_feature'].to(weight_dtype)
191
+ audio_feats = wav2vec.encoder(
192
+ audio_feats, output_hidden_states=True).hidden_states
193
+ audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
194
+
195
+ start_ts = batch['audio_offset']
196
+ step_ts = batch['audio_step']
197
+ audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
198
+ audio_feats,
199
+ torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
200
+ audio_prompts = []
201
+ for bb in range(bsz):
202
+ audio_feats_list = []
203
+ for f in range(num_frames):
204
+ cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
205
+ audio_clip = audio_feats[bb:bb+1,
206
+ cur_t: cur_t+audio_feature_length_per_frame]
207
+
208
+ audio_feats_list.append(audio_clip)
209
+ audio_feats_list = torch.stack(audio_feats_list, 1)
210
+ audio_prompts.append(audio_feats_list)
211
+ audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
212
+ return audio_prompts
213
+
214
+ def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
215
+ save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
216
+
217
+ if total_limit is not None:
218
+ checkpoints = os.listdir(save_dir)
219
+ checkpoints = [d for d in checkpoints if d.endswith(".pth")]
220
+ checkpoints = [d for d in checkpoints if name in d]
221
+ checkpoints = sorted(
222
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
223
+ )
224
+
225
+ if len(checkpoints) >= total_limit:
226
+ num_to_remove = len(checkpoints) - total_limit + 1
227
+ removing_checkpoints = checkpoints[0:num_to_remove]
228
+ logger.info(
229
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
230
+ )
231
+ logger.info(
232
+ f"removing checkpoints: {', '.join(removing_checkpoints)}")
233
+
234
+ for removing_checkpoint in removing_checkpoints:
235
+ removing_checkpoint = os.path.join(
236
+ save_dir, removing_checkpoint)
237
+ os.remove(removing_checkpoint)
238
+
239
+ state_dict = model.state_dict()
240
+ torch.save(state_dict, save_path)
241
+
242
+ def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
243
+ unwrapped_net = accelerator.unwrap_model(net)
244
+ save_checkpoint(
245
+ unwrapped_net.unet,
246
+ save_dir,
247
+ global_step,
248
+ name="unet",
249
+ total_limit=cfg.total_limit,
250
+ logger=logger
251
+ )
252
+
253
+ def delete_additional_ckpt(base_path, num_keep):
254
+ dirs = []
255
+ for d in os.listdir(base_path):
256
+ if d.startswith("checkpoint-"):
257
+ dirs.append(d)
258
+ num_tot = len(dirs)
259
+ if num_tot <= num_keep:
260
+ return
261
+ # ensure checkpoints are sorted and delete the earliest ones
262
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
263
+ for d in del_dirs:
264
+ path_to_dir = osp.join(base_path, d)
265
+ if osp.exists(path_to_dir):
266
+ shutil.rmtree(path_to_dir)
267
+
268
+ def seed_everything(seed):
269
+ import random
270
+
271
+ import numpy as np
272
+
273
+ torch.manual_seed(seed)
274
+ torch.cuda.manual_seed_all(seed)
275
+ np.random.seed(seed % (2**32))
276
+ random.seed(seed)
277
+
278
+ def process_and_save_images(
279
+ batch,
280
+ image_pred,
281
+ image_pred_infer,
282
+ save_dir,
283
+ global_step,
284
+ accelerator,
285
+ num_images_to_keep=10,
286
+ syncnet_score=1
287
+ ):
288
+ # Rearrange the tensors
289
+ print("image_pred.shape: ", image_pred.shape)
290
+ pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
291
+ pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
292
+
293
+ # Create masked pixel values
294
+ masked_pixel_values = batch["pixel_values_vid"].clone()
295
+ _, _, _, h, _ = batch["pixel_values_vid"].shape
296
+ masked_pixel_values[:, :, :, h//2:, :] = -1
297
+ masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
298
+
299
+ # Keep only the specified number of images
300
+ pixel_values = pixel_values[:num_images_to_keep, :, :, :]
301
+ masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
302
+ pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
303
+ image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
304
+ image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
305
+
306
+ # Concatenate images
307
+ concat = torch.cat([
308
+ masked_pixel_values * 0.5 + 0.5,
309
+ pixel_values_ref_img * 0.5 + 0.5,
310
+ image_pred * 0.5 + 0.5,
311
+ pixel_values * 0.5 + 0.5,
312
+ image_pred_infer * 0.5 + 0.5,
313
+ ], dim=2)
314
+ print("concat.shape: ", concat.shape)
315
+
316
+ # Create the save directory if it doesn't exist
317
+ os.makedirs(f'{save_dir}/samples/', exist_ok=True)
318
+
319
+ # Try to save the concatenated image
320
+ try:
321
+ # Concatenate images horizontally and convert to numpy array
322
+ final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
323
+ # Save the image
324
+ cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
325
+ print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
326
+ except Exception as e:
327
+ print(f"Failed to save image: {e}")
scripts/preprocess.py ADDED
@@ -0,0 +1,322 @@
1
+ import os
2
+ import argparse
3
+ import subprocess
4
+ from omegaconf import OmegaConf
5
+ from typing import Tuple, List, Union
6
+ import decord
7
+ import json
8
+ import cv2
9
+ from musetalk.utils.face_detection import FaceAlignment,LandmarksType
10
+ from mmpose.apis import inference_topdown, init_model
11
+ from mmpose.structures import merge_data_samples
12
+ import torch
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+
16
+ ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
17
+ if ffmpeg_path not in os.getenv('PATH'):
18
+ print("add ffmpeg to path")
19
+ os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"
20
+
21
+ class AnalyzeFace:
22
+ def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
23
+ """
24
+ Initialize the AnalyzeFace class with the given device, config file, and checkpoint file.
25
+
26
+ Parameters:
27
+ device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu').
28
+ config_file (str): Path to the mmpose model configuration file.
29
+ checkpoint_file (str): Path to the mmpose model checkpoint file.
30
+ """
31
+ self.device = device
32
+ self.dwpose = init_model(config_file, checkpoint_file, device=self.device)
33
+ self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device)
34
+
35
+ def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
36
+ """
37
+ Detect faces and keypoints in the given image.
38
+
39
+ Parameters:
40
+ im (np.ndarray): The input image, with shape (H, W, C)
41
+ or (1, H, W, C); a single image is expanded to a batch of one.
42
+
43
+ Returns:
44
+ Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints.
45
+ """
46
+ try:
47
+ # Ensure the input image has the correct shape
48
+ if im.ndim == 3:
49
+ im = np.expand_dims(im, axis=0)
50
+ elif im.ndim != 4 or im.shape[0] != 1:
51
+ raise ValueError("Input image must have shape (1, H, W, C)")
52
+
53
+ bbox = self.facedet.get_detections_for_batch(np.asarray(im))
54
+ results = inference_topdown(self.dwpose, np.asarray(im)[0])
55
+ results = merge_data_samples(results)
56
+ keypoints = results.pred_instances.keypoints
57
+ face_land_mark = keypoints[0][23:91]
58
+ face_land_mark = face_land_mark.astype(np.int32)
59
+
60
+ return face_land_mark, bbox
61
+
62
+ except Exception as e:
63
+ print(f"Error during face analysis: {e}")
64
+ return np.array([]),[]
65
+
66
+ def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
67
+
68
+ """
69
+ Convert video files to a specified format and save them to the destination path.
70
+
71
+ Parameters:
72
+ org_path (str): The directory containing the original video files.
73
+ dst_path (str): The directory where the converted video files will be saved.
74
+ vid_list (List[str]): A list of video file names to process.
75
+
76
+ Returns:
77
+ None
78
+ """
79
+ for idx, vid in enumerate(vid_list):
80
+ if vid.endswith('.mp4'):
81
+ org_vid_path = os.path.join(org_path, vid)
82
+ dst_vid_path = os.path.join(dst_path, vid)
83
+
84
+ if org_vid_path != dst_vid_path:
85
+ cmd = [
86
+ "ffmpeg", "-hide_banner", "-y", "-i", org_vid_path,
87
+ "-r", "25", "-crf", "15", "-c:v", "libx264",
88
+ "-pix_fmt", "yuv420p", dst_vid_path
89
+ ]
90
+ subprocess.run(cmd, check=True)
91
+
92
+ if idx % 1000 == 0:
93
+ print(f"### {idx} videos converted ###")
94
+
95
+ def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None:
96
+ """
97
+ Segment video files into smaller clips of specified duration.
98
+
99
+ Parameters:
100
+ org_path (str): The directory containing the original video files.
101
+ dst_path (str): The directory where the segmented video files will be saved.
102
+ vid_list (List[str]): A list of video file names to process.
103
+ segment_duration (int): The duration of each segment in seconds. Default is 30 seconds.
104
+
105
+ Returns:
106
+ None
107
+ """
108
+ for idx, vid in enumerate(vid_list):
109
+ if vid.endswith('.mp4'):
110
+ input_file = os.path.join(org_path, vid)
111
+ original_filename = os.path.basename(input_file)
112
+
113
+ command = [
114
+ 'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0',
115
+ '-segment_time', str(segment_duration), '-f', 'segment',
116
+ '-reset_timestamps', '1',
117
+ os.path.join(dst_path, f'clip%03d_{original_filename}')
118
+ ]
119
+
120
+ subprocess.run(command, check=True)
121
+
122
+ def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None:
123
+ """
124
+ Extract audio from video files and save as WAV format.
125
+
126
+ Parameters:
127
+ org_path (str): The directory containing the original video files.
128
+ dst_path (str): The directory where the extracted audio files will be saved.
129
+ vid_list (List[str]): A list of video file names to process.
130
+
131
+ Returns:
132
+ None
133
+ """
134
+ for idx, vid in enumerate(vid_list):
135
+ if vid.endswith('.mp4'):
136
+ video_path = os.path.join(org_path, vid)
137
+ audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav")
138
+ try:
139
+ command = [
140
+ 'ffmpeg', '-hide_banner', '-y', '-i', video_path,
141
+ '-vn', '-acodec', 'pcm_s16le', '-f', 'wav',
142
+ '-ar', '16000', '-ac', '1', audio_output_path,
143
+ ]
144
+
145
+ subprocess.run(command, check=True)
146
+ print(f"Audio saved to: {audio_output_path}")
147
+ except subprocess.CalledProcessError as e:
148
+ print(f"Error extracting audio from {vid}: {e}")
149
+
150
+ def split_data(video_files: List[str], val_list_hdtf: List[str]) -> (List[str], List[str]):
151
+ """
152
+ Split video files into training and validation sets based on val_list_hdtf.
153
+
154
+ Parameters:
155
+ video_files (List[str]): A list of video file names.
156
+ val_list_hdtf (List[str]): A list of validation file identifiers.
157
+
158
+ Returns:
159
+ (List[str], List[str]): A tuple containing the training and validation file lists.
160
+ """
161
+ val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)]
162
+ train_files = [f for f in video_files if f not in val_files]
163
+ return train_files, val_files
164
+
165
+ def save_list_to_file(file_path: str, data_list: List[str]) -> None:
166
+ """
167
+ Save a list of strings to a file, each string on a new line.
168
+
169
+ Parameters:
170
+ file_path (str): The path to the file where the list will be saved.
171
+ data_list (List[str]): The list of strings to save.
172
+
173
+ Returns:
174
+ None
175
+ """
176
+ with open(file_path, 'w') as file:
177
+ for item in data_list:
178
+ file.write(f"{item}\n")
179
+
180
+ def generate_train_list(cfg):
181
+ train_file_path = cfg.video_clip_file_list_train
182
+ val_file_path = cfg.video_clip_file_list_val
183
+ val_list_hdtf = cfg.val_list_hdtf
184
+
185
+ meta_list = os.listdir(cfg.meta_root)
186
+
187
+ sorted_meta_list = sorted(meta_list)
188
+ train_files, val_files = split_data(meta_list, val_list_hdtf)
189
+
190
+ save_list_to_file(train_file_path, train_files)
191
+ save_list_to_file(val_file_path, val_files)
192
+
193
+ print(val_list_hdtf)
194
+
195
+ def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
196
+ """
197
+ Analyze video files to detect face bounding boxes and landmarks, and save the per-clip metadata as JSON.
198
+
199
+ Parameters:
200
+ org_path (str): The directory containing the original video files.
201
+ dst_path (str): The directory where the meta json will be saved.
202
+ vid_list (List[str]): A list of video file names to process.
203
+
204
+ Returns:
205
+ None
206
+ """
207
+ device = "cuda" if torch.cuda.is_available() else "cpu"
208
+ config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
209
+ checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
210
+
211
+ analyze_face = AnalyzeFace(device, config_file, checkpoint_file)
212
+
213
+ for vid in tqdm(vid_list, desc="Processing videos"):
214
+ #vid = "clip005_WDA_BernieSanders_000.mp4"
215
+ #print(vid)
216
+ if vid.endswith('.mp4'):
217
+ vid_path = os.path.join(org_path, vid)
218
+ wav_path = vid_path.replace(".mp4",".wav")
219
+ vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json")
220
+ if os.path.exists(vid_meta):
221
+ continue
222
+ print('process video {}'.format(vid))
223
+
224
+ total_bbox_list = []
225
+ total_pts_list = []
226
+ isvalid = True
227
+
228
+ # process
229
+ try:
230
+ cap = decord.VideoReader(vid_path, fault_tol=1)
231
+ except Exception as e:
232
+ print(e)
233
+ continue
234
+
235
+ total_frames = len(cap)
236
+ for frame_idx in range(total_frames):
237
+ frame = cap[frame_idx]
238
+ if frame_idx==0:
239
+ video_height,video_width,_ = frame.shape
240
+ frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB)
241
+ pts_list, bbox_list = analyze_face(frame_bgr)
242
+
243
+ if len(bbox_list)>0 and None not in bbox_list:
244
+ bbox = bbox_list[0]
245
+ else:
246
+ isvalid = False
247
+ bbox = []
248
+ print(f"set isvalid to False as broken img in {frame_idx} of {vid}")
249
+ break
250
+
251
+ #print(pts_list)
252
+ if len(pts_list)>0 and pts_list is not None:
253
+ pts = pts_list.tolist()
254
+ else:
255
+ isvalid = False
256
+ pts = []
257
+ break
258
+
259
+ if frame_idx==0:
260
+ x1,y1,x2,y2 = bbox
261
+ face_height, face_width = y2-y1,x2-x1
262
+
263
+ total_pts_list.append(pts)
264
+ total_bbox_list.append(bbox)
265
+
266
+ meta_data = {
267
+ "mp4_path": vid_path,
268
+ "wav_path": wav_path,
269
+ "video_size": [video_height, video_width],
270
+ "face_size": [face_height, face_width],
271
+ "frames": total_frames,
272
+ "face_list": total_bbox_list,
273
+ "landmark_list": total_pts_list,
274
+ "isvalid":isvalid,
275
+ }
276
+ with open(vid_meta, 'w') as f:
277
+ json.dump(meta_data, f, indent=4)
278
+
279
+
280
+
281
+ def main(cfg):
282
+ # Ensure all necessary directories exist
283
+ os.makedirs(cfg.video_root_25fps, exist_ok=True)
284
+ os.makedirs(cfg.video_audio_clip_root, exist_ok=True)
285
+ os.makedirs(cfg.meta_root, exist_ok=True)
286
+ os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True)
287
+ os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True)
288
+ os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True)
289
+
290
+ vid_list = os.listdir(cfg.video_root_raw)
291
+ sorted_vid_list = sorted(vid_list)
292
+
293
+ # Save video file list
294
+ with open(cfg.video_file_list, 'w') as file:
295
+ for vid in sorted_vid_list:
296
+ file.write(vid + '\n')
297
+
298
+ # 1. Convert videos to 25 FPS
299
+ convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list)
300
+
301
+ # 2. Segment videos into 30-second clips
302
+ segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second)
303
+
304
+ # 3. Extract audio
305
+ clip_vid_list = os.listdir(cfg.video_audio_clip_root)
306
+ extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list)
307
+
308
+ # 4. Generate video metadata
309
+ analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list)
310
+
311
+ # 5. Generate training and validation set lists
312
+ generate_train_list(cfg)
313
+ print("done")
314
+
315
+ if __name__ == "__main__":
316
+ parser = argparse.ArgumentParser()
317
+ parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml")
318
+ args = parser.parse_args()
319
+ config = OmegaConf.load(args.config)
320
+
321
+ main(config)
322
+
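
The preprocessing entry point loads `configs/training/preprocess.yaml`, which is not part of this diff. As a minimal sketch, the keys accessed by `main(cfg)` and `generate_train_list(cfg)` are listed below; the key names come from the code above, while every path and value is an illustrative assumption.

```python
from omegaconf import OmegaConf

# Placeholder config mirroring the keys used in scripts/preprocess.py.
cfg = OmegaConf.create({
    "video_root_raw": "./dataset/HDTF/raw",            # input videos
    "video_root_25fps": "./dataset/HDTF/25fps",        # step 1: re-encoded videos
    "video_audio_clip_root": "./dataset/HDTF/clips",   # steps 2-3: clips + wav files
    "meta_root": "./dataset/HDTF/meta",                # step 4: per-clip JSON metadata
    "video_file_list": "./dataset/HDTF/file_list.txt",
    "video_clip_file_list_train": "./dataset/HDTF/train_list.txt",
    "video_clip_file_list_val": "./dataset/HDTF/val_list.txt",
    "val_list_hdtf": ["WDA_BernieSanders_000"],        # held-out identifiers
    "clip_len_second": 30,
})
# main(cfg)  # would run conversion, segmentation, audio extraction and face analysis
```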
train.py ADDED
@@ -0,0 +1,580 @@
1
+ import argparse
2
+ import diffusers
3
+ import logging
4
+ import math
5
+ import os
6
+ import time
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ import warnings
13
+ import random
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import LoggerType
17
+ from accelerate import InitProcessGroupKwargs
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import DistributedDataParallelKwargs
20
+ from datetime import datetime
21
+ from datetime import timedelta
22
+
23
+ from diffusers.utils import check_min_version
24
+ from einops import rearrange
25
+ from omegaconf import OmegaConf
26
+ from tqdm.auto import tqdm
27
+
28
+ from musetalk.utils.utils import (
29
+ delete_additional_ckpt,
30
+ seed_everything,
31
+ get_mouth_region,
32
+ process_audio_features,
33
+ save_models
34
+ )
35
+ from musetalk.loss.basic_loss import set_requires_grad
36
+ from musetalk.loss.syncnet import get_sync_loss
37
+ from musetalk.utils.training_utils import (
38
+ initialize_models_and_optimizers,
39
+ initialize_dataloaders,
40
+ initialize_loss_functions,
41
+ initialize_syncnet,
42
+ initialize_vgg,
43
+ validation
44
+ )
45
+
46
+ logger = get_logger(__name__, log_level="INFO")
47
+ warnings.filterwarnings("ignore")
48
+ check_min_version("0.10.0.dev0")
49
+
50
+ def main(cfg):
51
+ exp_name = cfg.exp_name
52
+ save_dir = f"{cfg.output_dir}/{exp_name}"
53
+ os.makedirs(save_dir, exist_ok=True)
54
+
55
+ kwargs = DistributedDataParallelKwargs()
56
+ process_group_kwargs = InitProcessGroupKwargs(
57
+ timeout=timedelta(seconds=5400))
58
+ accelerator = Accelerator(
59
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
60
+ log_with=["tensorboard", LoggerType.TENSORBOARD],
61
+ project_dir=os.path.join(save_dir, "./tensorboard"),
62
+ kwargs_handlers=[kwargs, process_group_kwargs],
63
+ )
64
+
65
+ # Make one log on every process with the configuration for debugging.
66
+ logging.basicConfig(
67
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
68
+ datefmt="%m/%d/%Y %H:%M:%S",
69
+ level=logging.INFO,
70
+ )
71
+ logger.info(accelerator.state, main_process_only=False)
72
+ if accelerator.is_local_main_process:
73
+ transformers.utils.logging.set_verbosity_warning()
74
+ diffusers.utils.logging.set_verbosity_info()
75
+ else:
76
+ transformers.utils.logging.set_verbosity_error()
77
+ diffusers.utils.logging.set_verbosity_error()
78
+
79
+ # If passed along, set the training seed now.
80
+ if cfg.seed is not None:
81
+ print('cfg.seed', cfg.seed, accelerator.process_index)
82
+ seed_everything(cfg.seed + accelerator.process_index)
83
+
84
+ weight_dtype = torch.float32
85
+
86
+ model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
87
+ dataloader_dict = initialize_dataloaders(cfg)
88
+ loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
89
+ syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
90
+ vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
91
+
92
+ # Prepare everything with our `accelerator`.
93
+ model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
94
+ model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
95
+ )
96
+ print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))
97
+
98
+ # Calculate training steps and epochs
99
+ num_update_steps_per_epoch = math.ceil(
100
+ len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
101
+ )
102
+ num_train_epochs = math.ceil(
103
+ cfg.solver.max_train_steps / num_update_steps_per_epoch
104
+ )
105
+
106
+ # Initialize trackers on the main process
107
+ if accelerator.is_main_process:
108
+ run_time = datetime.now().strftime("%Y%m%d-%H%M")
109
+ accelerator.init_trackers(
110
+ cfg.exp_name,
111
+ init_kwargs={"mlflow": {"run_name": run_time}},
112
+ )
113
+
114
+ # Calculate total batch size
115
+ total_batch_size = (
116
+ cfg.data.train_bs
117
+ * accelerator.num_processes
118
+ * cfg.solver.gradient_accumulation_steps
119
+ )
120
+
121
+ # Log training information
122
+ logger.info("***** Running training *****")
123
+ logger.info(f"Num Epochs = {num_train_epochs}")
124
+ logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
125
+ logger.info(
126
+ f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
127
+ )
128
+ logger.info(
129
+ f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
130
+ logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")
131
+
132
+ global_step = 0
133
+ first_epoch = 0
134
+
135
+ # Load checkpoint if resuming training
136
+ if cfg.resume_from_checkpoint:
137
+ resume_dir = save_dir
138
+ dirs = os.listdir(resume_dir)
139
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
140
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
141
+ if len(dirs) > 0:
142
+ path = dirs[-1]
143
+ accelerator.load_state(os.path.join(resume_dir, path))
144
+ accelerator.print(f"Resuming from checkpoint {path}")
145
+ global_step = int(path.split("-")[1])
146
+ first_epoch = global_step // num_update_steps_per_epoch
147
+ resume_step = global_step % num_update_steps_per_epoch
148
+
149
+ # Initialize progress bar
150
+ progress_bar = tqdm(
151
+ range(global_step, cfg.solver.max_train_steps),
152
+ disable=not accelerator.is_local_main_process,
153
+ )
154
+ progress_bar.set_description("Steps")
155
+
156
+ # Log model types
157
+ print("log type of models")
158
+ print("unet", model_dict['unet'].dtype)
159
+ print("vae", model_dict['vae'].dtype)
160
+ print("wav2vec", model_dict['wav2vec'].dtype)
161
+
162
+ def get_ganloss_weight(step):
163
+ """Calculate GAN loss weight based on training step"""
164
+ if step < cfg.discriminator_train_params.start_gan:
165
+ return 0.0
166
+ else:
167
+ return 1.0
168
+
169
+ # Training loop
170
+ for epoch in range(first_epoch, num_train_epochs):
171
+ # Set models to training mode
172
+ model_dict['unet'].train()
173
+ if cfg.loss_params.gan_loss > 0:
174
+ loss_dict['discriminator'].train()
175
+ if cfg.loss_params.mouth_gan_loss > 0:
176
+ loss_dict['mouth_discriminator'].train()
177
+
178
+ # Initialize loss accumulators
179
+ train_loss = 0.0
180
+ train_loss_D = 0.0
181
+ train_loss_D_mouth = 0.0
182
+ l1_loss_accum = 0.0
183
+ vgg_loss_accum = 0.0
184
+ gan_loss_accum = 0.0
185
+ gan_loss_accum_mouth = 0.0
186
+ fm_loss_accum = 0.0
187
+ sync_loss_accum = 0.0
188
+ adapted_weight_accum = 0.0
189
+
190
+ t_data_start = time.time()
191
+ for step, batch in enumerate(dataloader_dict['train_dataloader']):
192
+ t_data = time.time() - t_data_start
193
+ t_model_start = time.time()
194
+
195
+ with torch.no_grad():
196
+ # Process input data
197
+ pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
198
+ accelerator.device,
199
+ non_blocking=True
200
+ )
201
+ bsz, num_frames, c, h, w = pixel_values.shape
202
+
203
+ # Process reference images
204
+ ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
205
+ accelerator.device,
206
+ non_blocking=True
207
+ )
208
+
209
+ # Get face mask for GAN
210
+ pixel_values_face_mask = batch['pixel_values_face_mask']
211
+
212
+ # Process audio features
213
+ audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)
214
+
215
+ # Initialize adapted weight
216
+ adapted_weight = 1
217
+
218
+ # Process sync loss if enabled
219
+ if cfg.loss_params.sync_loss > 0:
220
+ mels = batch['mel']
221
+ # Prepare frames for latentsync (combine channels and frames)
222
+ gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
223
+ # Use lower half of face for latentsync
224
+ height = gt_frames.shape[2]
225
+ gt_frames = gt_frames[:, :, height // 2:, :]
226
+
227
+ # Get audio embeddings
228
+ audio_embed = syncnet.get_audio_embed(mels)
229
+
230
+ # Calculate adapted weight based on audio-visual similarity
231
+ if cfg.use_adapted_weight:
232
+ vision_embed_gt = syncnet.get_vision_embed(gt_frames)
233
+ image_audio_sim_gt = F.cosine_similarity(
234
+ audio_embed,
235
+ vision_embed_gt,
236
+ dim=1
237
+ )[0]
238
+
239
+ if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
240
+ if cfg.adapted_weight_type == "cut_off":
241
+ adapted_weight = 0.0 # Skip this batch
242
+ print(
243
+ f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
244
+ elif cfg.adapted_weight_type == "linear":
245
+ adapted_weight = image_audio_sim_gt
246
+ else:
247
+ print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
248
+ adapted_weight = 1
249
+
250
+ # Random frame selection for memory efficiency
251
+ max_start = 16 - cfg.num_backward_frames
252
+ frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
253
+ frames_right_index = frames_left_index + cfg.num_backward_frames
254
+ else:
255
+ frames_left_index = 0
256
+ frames_right_index = cfg.data.n_sample_frames
257
+
258
+ # Extract frames for backward pass
259
+ pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
260
+ ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
261
+ pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
262
+ audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
263
+
264
+ # Encode target images
265
+ frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
266
+ latents = model_dict['vae'].encode(frames).latent_dist.mode()
267
+ latents = latents * model_dict['vae'].config.scaling_factor
268
+ latents = latents.float()
269
+
270
+ # Create masked images
271
+ masked_pixel_values = pixel_values_backward.clone()
272
+ masked_pixel_values[:, :, :, h//2:, :] = -1
273
+ masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
274
+ masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
275
+ masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
276
+ masked_latents = masked_latents.float()
277
+
278
+ # Encode reference images
279
+ ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
280
+ ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
281
+ ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
282
+ ref_latents = ref_latents.float()
283
+
284
+ # Prepare face mask and audio features
285
+ pixel_values_face_mask_backward = rearrange(
286
+ pixel_values_face_mask_backward,
287
+ "b f c h w -> (b f) c h w"
288
+ )
289
+ audio_prompts_backward = rearrange(
290
+ audio_prompts_backward,
291
+ 'b f c h w-> (b f) c h w'
292
+ )
293
+ audio_prompts_backward = rearrange(
294
+ audio_prompts_backward,
295
+ '(b f) c h w -> (b f) (c h) w',
296
+ b=bsz
297
+ )
298
+
299
+ # Apply reference dropout (currently inactive)
300
+ dropout = nn.Dropout(p=cfg.ref_dropout_rate)
301
+ ref_latents = dropout(ref_latents)
302
+
303
+ # Prepare model inputs
304
+ input_latents = torch.cat([masked_latents, ref_latents], dim=1)
305
+ input_latents = input_latents.to(weight_dtype)
306
+ timesteps = torch.tensor([0], device=input_latents.device)
307
+
308
+ # Forward pass
309
+ latents_pred = model_dict['net'](
310
+ input_latents,
311
+ timesteps,
312
+ audio_prompts_backward,
313
+ )
314
+ latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
315
+ image_pred = model_dict['vae'].decode(latents_pred).sample
316
+
317
+ # Convert to float
318
+ image_pred = image_pred.float()
319
+ frames = frames.float()
320
+
321
+ # Calculate L1 loss
322
+ l1_loss = loss_dict['L1_loss'](frames, image_pred)
323
+ l1_loss_accum += l1_loss.item()
324
+ loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
325
+
326
+ # Process mouth GAN loss if enabled
327
+ if cfg.loss_params.mouth_gan_loss > 0:
328
+ frames_mouth, image_pred_mouth = get_mouth_region(
329
+ frames,
330
+ image_pred,
331
+ pixel_values_face_mask_backward
332
+ )
333
+ pyramide_real_mouth = pyramid(downsampler(frames_mouth))
334
+ pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))
335
+
336
+ # Process VGG loss if enabled
337
+ if cfg.loss_params.vgg_loss > 0:
338
+ pyramide_real = pyramid(downsampler(frames))
339
+ pyramide_generated = pyramid(downsampler(image_pred))
340
+
341
+ loss_IN = 0
342
+ for scale in cfg.loss_params.pyramid_scale:
343
+ x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
344
+ y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
345
+ for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
346
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
347
+ loss_IN += weight * value
348
+ loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
349
+ loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
350
+ vgg_loss_accum += loss_IN.item()
351
+
352
+ # Process GAN loss if enabled
353
+ if cfg.loss_params.gan_loss > 0:
354
+ set_requires_grad(loss_dict['discriminator'], False)
355
+ loss_G = 0.
356
+ discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
357
+ discriminator_maps_real = loss_dict['discriminator'](pyramide_real)
358
+
359
+ for scale in loss_dict['disc_scales']:
360
+ key = 'prediction_map_%s' % scale
361
+ value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
362
+ loss_G += value
363
+ gan_loss_accum += loss_G.item()
364
+
365
+ loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight
366
+
367
+ # Process feature matching loss if enabled
368
+ if cfg.loss_params.fm_loss[0] > 0:
369
+ L_feature_matching = 0.
370
+ for scale in loss_dict['disc_scales']:
371
+ key = 'feature_maps_%s' % scale
372
+ for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
373
+ value = torch.abs(a - b).mean()
374
+ L_feature_matching += value * cfg.loss_params.fm_loss[i]
375
+ loss += L_feature_matching * adapted_weight
376
+ fm_loss_accum += L_feature_matching.item()
377
+
378
+ # Process mouth GAN loss if enabled
379
+ if cfg.loss_params.mouth_gan_loss > 0:
380
+ set_requires_grad(loss_dict['mouth_discriminator'], False)
381
+ loss_G = 0.
382
+ mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
383
+ mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)
384
+
385
+ for scale in loss_dict['disc_scales']:
386
+ key = 'prediction_map_%s' % scale
387
+ value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
388
+ loss_G += value
389
+ gan_loss_accum_mouth += loss_G.item()
390
+
391
+ loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight
392
+
393
+ # Process feature matching loss for mouth if enabled
394
+ if cfg.loss_params.fm_loss[0] > 0:
395
+ L_feature_matching = 0.
396
+ for scale in loss_dict['disc_scales']:
397
+ key = 'feature_maps_%s' % scale
398
+ for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
399
+ value = torch.abs(a - b).mean()
400
+ L_feature_matching += value * cfg.loss_params.fm_loss[i]
401
+ loss += L_feature_matching * adapted_weight
402
+ fm_loss_accum += L_feature_matching.item()
403
+
404
+ # Process sync loss if enabled
405
+ if cfg.loss_params.sync_loss > 0:
406
+ pred_frames = rearrange(
407
+ image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
408
+ pred_frames = pred_frames[:, :, height // 2 :, :]
409
+ sync_loss, image_audio_sim_pred = get_sync_loss(
410
+ audio_embed,
411
+ gt_frames,
412
+ pred_frames,
413
+ syncnet,
414
+ adapted_weight,
415
+ frames_left_index=frames_left_index,
416
+ frames_right_index=frames_right_index,
417
+ )
418
+ sync_loss_accum += sync_loss.item()
419
+ loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight
420
+
421
+ # Backward pass
422
+ avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
423
+ train_loss += avg_loss.item()
424
+ accelerator.backward(loss)
425
+
426
+ # Train discriminator if GAN loss is enabled
427
+ if cfg.loss_params.gan_loss > 0:
428
+ set_requires_grad(loss_dict['discriminator'], True)
429
+ loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
430
+ avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
431
+ train_loss_D += avg_loss_D.item() / 1
432
+ loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
433
+ accelerator.backward(loss_D)
434
+
435
+ if accelerator.sync_gradients:
436
+ accelerator.clip_grad_norm_(
437
+ loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
438
+ if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
439
+ loss_dict['optimizer_D'].step()
440
+ loss_dict['scheduler_D'].step()
441
+ loss_dict['optimizer_D'].zero_grad()
442
+
443
+ # Train mouth discriminator if mouth GAN loss is enabled
444
+ if cfg.loss_params.mouth_gan_loss > 0:
445
+ set_requires_grad(loss_dict['mouth_discriminator'], True)
446
+ mouth_loss_D = loss_dict['mouth_discriminator_full'](
447
+ frames_mouth, image_pred_mouth.detach())
448
+ avg_mouth_loss_D = accelerator.gather(
449
+ mouth_loss_D.repeat(cfg.data.train_bs)).mean()
450
+ train_loss_D_mouth += avg_mouth_loss_D.item() / 1
451
+ mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
452
+ accelerator.backward(mouth_loss_D)
453
+
454
+ if accelerator.sync_gradients:
455
+ accelerator.clip_grad_norm_(
456
+ loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
457
+ if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
458
+ loss_dict['mouth_optimizer_D'].step()
459
+ loss_dict['mouth_scheduler_D'].step()
460
+ loss_dict['mouth_optimizer_D'].zero_grad()
461
+
462
+ # Update main model
463
+ if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
464
+ if accelerator.sync_gradients:
465
+ accelerator.clip_grad_norm_(
466
+ model_dict['trainable_params'],
467
+ cfg.solver.max_grad_norm,
468
+ )
469
+ model_dict['optimizer'].step()
470
+ model_dict['lr_scheduler'].step()
471
+ model_dict['optimizer'].zero_grad()
472
+
473
+ # Update progress and log metrics
474
+ if accelerator.sync_gradients:
475
+ progress_bar.update(1)
476
+ global_step += 1
477
+ accelerator.log({
478
+ "train_loss": train_loss,
479
+ "train_loss_D": train_loss_D,
480
+ "train_loss_D_mouth": train_loss_D_mouth,
481
+ "l1_loss": l1_loss_accum,
482
+ "vgg_loss": vgg_loss_accum,
483
+ "gan_loss": gan_loss_accum,
484
+ "fm_loss": fm_loss_accum,
485
+ "sync_loss": sync_loss_accum,
486
+ "adapted_weight": adapted_weight_accum,
487
+ "lr": model_dict['lr_scheduler'].get_last_lr()[0],
488
+ }, step=global_step)
489
+
490
+ # Reset loss accumulators
491
+ train_loss = 0.0
492
+ l1_loss_accum = 0.0
493
+ vgg_loss_accum = 0.0
494
+ gan_loss_accum = 0.0
495
+ fm_loss_accum = 0.0
496
+ sync_loss_accum = 0.0
497
+ adapted_weight_accum = 0.0
498
+ train_loss_D = 0.0
499
+ train_loss_D_mouth = 0.0
500
+
501
+ # Run validation if needed
502
+ if global_step % cfg.val_freq == 0 or global_step == 10:
503
+ try:
504
+ validation(
505
+ cfg,
506
+ dataloader_dict['val_dataloader'],
507
+ model_dict['net'],
508
+ model_dict['vae'],
509
+ model_dict['wav2vec'],
510
+ accelerator,
511
+ save_dir,
512
+ global_step,
513
+ weight_dtype,
514
+ syncnet_score=adapted_weight,
515
+ )
516
+ except Exception as e:
517
+ print(f"An error occurred during validation: {e}")
518
+
519
+ # Save checkpoint if needed
520
+ if global_step % cfg.checkpointing_steps == 0:
521
+ save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
522
+ try:
523
+ start_time = time.time()
524
+ if accelerator.is_main_process:
525
+ save_models(
526
+ accelerator,
527
+ model_dict['net'],
528
+ save_dir,
529
+ global_step,
530
+ cfg,
531
+ logger=logger
532
+ )
533
+ delete_additional_ckpt(save_dir, cfg.total_limit)
534
+ elapsed_time = time.time() - start_time
535
+ if elapsed_time > 300:
536
+ print(f"Skipping storage as it took too long in step {global_step}.")
537
+ else:
538
+ print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
539
+ except Exception as e:
540
+ print(f"Error when saving model in step {global_step}:", e)
541
+
542
+ # Update progress bar
543
+ t_model = time.time() - t_model_start
544
+ logs = {
545
+ "step_loss": loss.detach().item(),
546
+ "lr": model_dict['lr_scheduler'].get_last_lr()[0],
547
+ "td": f"{t_data:.2f}s",
548
+ "tm": f"{t_model:.2f}s",
549
+ }
550
+ t_data_start = time.time()
551
+ progress_bar.set_postfix(**logs)
552
+
553
+ if global_step >= cfg.solver.max_train_steps:
554
+ break
555
+
556
+ # Save model after each epoch
557
+ if (epoch + 1) % cfg.save_model_epoch_interval == 0:
558
+ try:
559
+ start_time = time.time()
560
+ if accelerator.is_main_process:
561
+ save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
562
+ accelerator.save_state(save_path)
563
+ elapsed_time = time.time() - start_time
564
+ if elapsed_time > 120:
565
+ print(f"Skipping storage as it took too long in step {global_step}.")
566
+ else:
567
+ print(f"Model saved successfully in {elapsed_time}s.")
568
+ except Exception as e:
569
+ print(f"Error when saving model in step {global_step}:", e)
570
+ accelerator.wait_for_everyone()
571
+
572
+ # End training
573
+ accelerator.end_training()
574
+
575
+ if __name__ == "__main__":
576
+ parser = argparse.ArgumentParser()
577
+ parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
578
+ args = parser.parse_args()
579
+ config = OmegaConf.load(args.config)
580
+ main(config)
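
The `adapted_weight` logic in the training loop above gates each batch by the SyncNet similarity between the ground-truth frames and the audio. As a compact sketch of that decision rule (the thresholds 0.05 and 0.65 come from the code above; the similarity values passed in are made up):

```python
def adapted_weight_from_similarity(sim: float, mode: str = "cut_off") -> float:
    # Inside the trusted band the batch keeps full weight.
    if 0.05 <= sim <= 0.65:
        return 1.0
    # Outside the band the batch is either skipped or re-weighted by its similarity.
    if mode == "cut_off":
        return 0.0
    if mode == "linear":
        return sim
    return 1.0  # unknown adapted_weight_type falls back to full weight


print(adapted_weight_from_similarity(0.72, "cut_off"))  # 0.0 -> losses for this batch vanish
print(adapted_weight_from_similarity(0.72, "linear"))   # 0.72 -> losses scaled by the similarity
```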
train.sh ADDED
@@ -0,0 +1,34 @@
1
+ #!/bin/bash
2
+
3
+ # MuseTalk Training Script
4
+ # This script combines both training stages for the MuseTalk model
5
+ # Usage: sh train.sh [stage1|stage2]
6
+ # Example: sh train.sh stage1 # To run stage 1 training
7
+ # Example: sh train.sh stage2 # To run stage 2 training
8
+
9
+ # Check if stage argument is provided
10
+ if [ $# -ne 1 ]; then
11
+ echo "Error: Please specify the training stage"
12
+ echo "Usage: ./train.sh [stage1|stage2]"
13
+ exit 1
14
+ fi
15
+
16
+ STAGE=$1
17
+
18
+ # Validate stage argument
19
+ if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
20
+ echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
21
+ exit 1
22
+ fi
23
+
24
+ # Launch distributed training using accelerate
25
+ # --config_file: Path to the GPU configuration file
26
+ # --main_process_port: Port number for the main process, used for distributed training communication
27
+ # train.py: Training script
28
+ # --config: Path to the training configuration file
29
+ echo "Starting $STAGE training..."
30
+ accelerate launch --config_file ./configs/training/gpu.yaml \
31
+ --main_process_port 29502 \
32
+ train.py --config ./configs/training/$STAGE.yaml
33
+
34
+ echo "Training completed for $STAGE"