tedlasai committed
Commit 199f9c2 · 1 Parent(s): 5aaa283
LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Stability AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py CHANGED
@@ -1,7 +1,118 @@
+ import os
+ import uuid
+ from pathlib import Path
+ import argparse
+
  import gradio as gr
-
- def greet(name):
-     return "hi " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from PIL import Image
+ from diffusers.utils import export_to_video
+
+ from inference import load_model, inference_on_image
+
+ # -----------------------
+ # 1. Load model
+ # -----------------------
+ args = argparse.Namespace()
+ args.blur2vid_hf_repo_path = "tedlasai/blur2vid"
+ args.pretrained_model_path = "THUDM/CogVideoX-2b"
+ args.model_config_path = "training/configs/outsidephotos.yaml"
+ args.video_width = 1280
+ args.video_height = 720
+ args.seed = None
+
+ pipe, model_config = load_model(args)
+
+ OUTPUT_DIR = Path("/tmp/generated_videos")
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+
+ def generate_video_from_image(image: Image.Image, interval_key: str, num_inference_steps: int) -> str:
+     """
+     Wrapper for Gradio. Takes an image and returns a video path.
+     """
+     if image is None:
+         raise gr.Error("Please upload an image first.")
+
+     print("Generating video")
+     import torch
+     print("CUDA:", torch.cuda.is_available())
+     if torch.cuda.is_available():
+         print("Device:", torch.cuda.get_device_name(0), "| bf16 supported:", torch.cuda.is_bf16_supported())
+
+     args.num_inference_steps = num_inference_steps
+
+     video_id = uuid.uuid4().hex
+     output_path = OUTPUT_DIR / f"{video_id}.mp4"
+
+     args.device = "cuda"
+
+     pipe.to(args.device)
+     processed_image, video = inference_on_image(pipe, image, interval_key, model_config, args)
+     export_to_video(video, str(output_path), fps=20)
+
+     if not os.path.exists(output_path):
+         raise gr.Error("Video generation failed: output file not found.")
+
+     return str(output_path)
+
+
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
+     gr.Markdown(
+         """
+         # 🖼️ ➜ 🎬 Recover Motion from a Blurry Image
+
+         This demo accompanies the paper **“Generating the Past, Present, and Future from a Motion-Blurred Image”**
+         by Tedla *et al.*, ACM Transactions on Graphics (SIGGRAPH Asia 2025).
+
+         - 🌐 **Project page:** <https://blur2vid.github.io/>
+         - 💻 **Code:** <https://github.com/tedlasai/blur2vid/>
+
+         Upload a blurry image and the model will generate a short video showing the recovered motion based on your selection.
+         Note: The image will be resized to 1280×720. We recommend uploading landscape-oriented images.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             image_in = gr.Image(
+                 type="pil",
+                 label="Input image",
+                 interactive=True,
+             )
+
+             with gr.Row():
+                 tense_choice = gr.Radio(
+                     label="Select the interval to be generated:",
+                     choices=["present", "past, present and future"],
+                     value="past, present and future",
+                     interactive=True,
+                 )
+
+             num_inference_steps = gr.Slider(
+                 label="Number of inference steps",
+                 minimum=4,
+                 maximum=50,
+                 step=1,
+                 value=20,
+                 info="More steps = better quality but slower",
+             )
+
+             generate_btn = gr.Button("Generate video", variant="primary")
+
+         with gr.Column():
+             video_out = gr.Video(
+                 label="Generated video",
+                 format="mp4",
+                 autoplay=True,
+                 loop=True,
+             )
+
+     generate_btn.click(
+         fn=generate_video_from_image,
+         inputs=[image_in, tense_choice, num_inference_steps],
+         outputs=video_out,
+         api_name="predict",
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
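For reference, since the click handler is registered with `api_name="predict"`, the Space can also be driven programmatically. A minimal sketch using `gradio_client`, assuming the Space is published as `tedlasai/blur2vid` and that a local `blurry_photo.jpg` exists:

```python
# Sketch: call the deployed demo remotely via gradio_client.
# The Space id and the local image path are assumptions.
from gradio_client import Client, handle_file

client = Client("tedlasai/blur2vid")
result = client.predict(
    handle_file("blurry_photo.jpg"),   # image_in
    "past, present and future",        # tense_choice
    20,                                # num_inference_steps
    api_name="/predict",
)
print(result)  # reference to the generated mp4 returned by the endpoint
```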
extra/compute_metrics.py ADDED
@@ -0,0 +1,136 @@
+ import torchmetrics
+ import os
+ import torch
+ from PIL import Image
+ import numpy as np
+ import csv
+ import sys
+
+ num_positions = 9
+
+ output_dir_path = "/datasets/sai/focal-burst-learning/metrics_output"
+
+
+
+ gt = "gt"
+ model = sys.argv[1]
+
+ gt_path = os.path.join(output_dir_path, gt)
+ model_path = os.path.join(output_dir_path, model)
+
+ device = sys.argv[2]
+
+ metrics_grid = []
+
+ for i in range(num_positions):
+     row = []
+     for j in range(num_positions):
+         metrics = {
+             "psnr": torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0).to(device),
+             "ssim": torchmetrics.image.StructuralSimilarityIndexMeasure().to(device),
+             "lpips": torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).to(device),
+             "fid": torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device),
+             "vif": torchmetrics.image.VisualInformationFidelity().to(device),
+         }
+         row.append(metrics)
+     metrics_grid.append(row)
+     print("Created metrics for position", i)
+
+ # loop through each directory in gt_path
+ # get all directories in gt_path
+ position_dirs = os.listdir(gt_path)
+ position_dirs = sorted([dir for dir in position_dirs if os.path.isdir(os.path.join(gt_path, dir))])[0:num_positions]
+
+ for gt_dir in position_dirs:
+     position_number = int(gt_dir.split("_")[1])
+     # get the pngs inside that directory
+     gt_pngs = sorted(os.listdir(os.path.join(gt_path, gt_dir, "images")))
+     # confirm that the number of pngs == 164*9
+     assert len(gt_pngs) == 164*9
+     # loop through the 164 images
+     for i in range(164):
+         # get the 9 frames
+         gt_frames_names = gt_pngs[i*9:(i+1)*9]
+         # load the 9 frames
+         gt_frames = [Image.open(os.path.join(gt_path, gt_dir, "images", frame)) for frame in gt_frames_names]
+         # make into float tensors in [0, 1], shaped (1, C, H, W)
+         gt_frames = [torch.tensor(np.array(frame)/255).to(torch.float32).to(device).permute(2,0,1).unsqueeze(0) for frame in gt_frames]
+
+         # load model_frames from the same relative path, but under model_path
+         model_frames = [Image.open(os.path.join(model_path, gt_dir, "images", frame)) for frame in gt_frames_names]
+         # make into float tensors in [0, 1], shaped (1, C, H, W)
+         model_frames = [torch.tensor(np.array(frame)/255).to(torch.float32).to(device).permute(2,0,1).unsqueeze(0) for frame in model_frames]
+
+         # loop through the 9 frames
+         for j in range(num_positions):
+             # compute metrics
+             for key, metric in metrics_grid[position_number][j].items():
+                 # if frames have a 4th (alpha) channel, discard it
+                 if gt_frames[j].shape[1] == 4:
+                     gt_frames[j] = gt_frames[j][:,:3,:,:]
+                 if model_frames[j].shape[1] == 4:
+                     model_frames[j] = model_frames[j][:,:3,:,:]
+                 if key == "fid":
+                     metric.update(model_frames[j], real=False)
+                     metric.update(gt_frames[j], real=True)
+                 else:
+                     metric(gt_frames[j], model_frames[j])
+
+         print("Computed metrics for position", position_number, "frame", i)
+
+ # write the metrics to csv files (one csv per metric)
+
+ def write_metrics_to_csv(metrics_grid, metric_names, formatting_options=None, output_dir="metrics_output"):
+     """
+     Writes each metric in the metrics_grid to a separate CSV file.
+
+     Args:
+         metrics_grid (list): A 9x9 list of dictionaries containing metrics.
+         metric_names (list): List of metric names (e.g., ["psnr", "lpips", "fid"]).
+         output_dir (str): Directory where the CSV files will be saved.
+     """
+     import os
+     os.makedirs(output_dir, exist_ok=True)  # Create output directory if it doesn't exist
+
+     positions = list(range(1, num_positions+1))
+
+     for metric_name in metric_names:
+         output_file = os.path.join(output_dir, f"{metric_name}.csv")
+
+         # Get the formatting function for the current metric, or use the default
+         format_fn = formatting_options.get(metric_name, lambda x: f"{x}") if formatting_options else lambda x: f"{x}"
+
+
+         # Write the metric to the CSV
+         with open(output_file, mode='w', newline='') as csv_file:
+             writer = csv.writer(csv_file)
+
+             header = ["Starting Position/End Position"] + [f"Position {i}" for i in positions]
+             writer.writerow(header)
+
+             # Iterate over the grid and extract the metric values
+             for i, row in enumerate(metrics_grid):
+                 csv_row = [f"Position {positions[i]}"]  # Add the row label as the first column
+                 for cell in row:
+                     metric = cell[metric_name]
+                     # torchmetrics objects expose a `compute` method;
+                     # fall back to 0.0 if a cell was never updated
+                     value = 0.0 if not hasattr(metric, "compute") else metric.compute().item()
+                     csv_row.append(format_fn(value))  # Format the value
+                 writer.writerow(csv_row)
+                 print(f"Wrote row for position {positions[i]} with metric {metric_name}")
+
+         print(f"Saved {metric_name} metrics to {output_file}")
+
+ formatting_options = {
+     "psnr": lambda x: f"{x:.2f}",   # Two decimal places
+     "lpips": lambda x: f"{x:.4f}",  # Four decimal places
+     "fid": lambda x: f"{x:.2f}",    # Two decimal places
+     "ssim": lambda x: f"{x:.4f}",   # Four decimal places
+     "vif": lambda x: f"{x:.4f}"     # Four decimal places
+ }
+
+
+
+ write_metrics_to_csv(metrics_grid, ["psnr", "ssim", "lpips", "fid", "vif"], formatting_options=formatting_options, output_dir=f"{output_dir_path}/metrics_output/{model}")
+
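The metric loop above mixes two torchmetrics calling conventions: the pairwise metrics (PSNR, SSIM, LPIPS, VIF) are called directly on each (gt, model) pair, while FID accumulates separate real/fake distributions via `update(..., real=...)` and only yields a value on `compute()`. A small self-contained sketch of the two patterns on random placeholder tensors (the values on noise are meaningless; this only illustrates the API):

```python
# Sketch: the two torchmetrics usage patterns used in compute_metrics.py.
import torch
import torchmetrics

psnr = torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)
fid = torchmetrics.image.fid.FrechetInceptionDistance(normalize=True)

preds = torch.rand(8, 3, 128, 128)   # stand-in for model frames in [0, 1]
target = torch.rand(8, 3, 128, 128)  # stand-in for ground-truth frames in [0, 1]

psnr(preds, target)            # updates internal state and returns the per-batch value
fid.update(preds, real=False)  # accumulate generated images
fid.update(target, real=True)  # accumulate reference images

print("PSNR:", psnr.compute().item())  # aggregate over everything seen so far
print("FID:", fid.compute().item())    # distributional, so only compute() is meaningful
```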
extra/download_dataset.py ADDED
File without changes
setup/checkpoints_to_hf.py ADDED
@@ -0,0 +1,12 @@
+ from huggingface_hub import HfApi
+ import os
+ # Run with HF_TOKEN=your_hf_token set in the environment before the python command
+ api = HfApi(token=os.getenv("HF_TOKEN"))
+ folders = ["/datasets/sai/focal-burst-learning/svd/checkpoints/checkpoint-200000"]
+ for folder in folders:
+     api.upload_folder(
+         folder_path=folder,
+         repo_id="tedlasai/learn2refocus",
+         repo_type="model",
+         path_in_repo=os.path.basename(folder)
+     )
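A quick way to confirm the upload landed is to list the repo contents afterwards. A small sketch using `HfApi.list_repo_files`, assuming the same `HF_TOKEN` is set:

```python
# Sketch: verify the uploaded checkpoint folder is present in the model repo.
from huggingface_hub import HfApi
import os

api = HfApi(token=os.getenv("HF_TOKEN"))
files = api.list_repo_files("tedlasai/learn2refocus", repo_type="model")
for f in files:
    if f.startswith("checkpoint-200000/"):
        print(f)
```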
setup/download_checkpoints.py ADDED
@@ -0,0 +1,35 @@
+ from huggingface_hub import snapshot_download
+ import os
+ import sys
+ # Make sure HF_TOKEN is set in your env beforehand:
+ # export HF_TOKEN=your_hf_token
+ # The first command-line argument selects the mode (defaults to "outsidephotos")
+
+
+ mode = sys.argv[1] if len(sys.argv) > 1 else "outsidephotos"
+
+
+ REPO_ID = "tedlasai/learn2refocus"
+ REPO_TYPE = "model"
+
+
+ checkpoints = [
+     "checkpoint-200000",
+ ]
+
+ # This is the root local directory where you want everything saved
+ # (resolved relative to the location of this file)
+ LOCAL_TRAINING_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "checkpoints")
+ os.makedirs(LOCAL_TRAINING_ROOT, exist_ok=True)
+
+ # Download only those folders from the repo and place them under LOCAL_TRAINING_ROOT
+ snapshot_download(
+     repo_id=REPO_ID,
+     repo_type=REPO_TYPE,
+     local_dir=LOCAL_TRAINING_ROOT,
+     local_dir_use_symlinks=False,
+     allow_patterns=[f"{name}/*" for name in checkpoints],
+     token=os.getenv("HF_TOKEN"),
+ )
+
+ print(f"Done! Checkpoints downloaded under: {LOCAL_TRAINING_ROOT}")
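Once downloaded, the UNet can be loaded from the local folder instead of the Hub, mirroring the `from_pretrained` call in `simplified_inference.py`. A sketch, assuming the snapshot above produced `checkpoints/checkpoint-200000/unet/`:

```python
# Sketch: load the downloaded UNet from the local checkpoints directory.
import os
from diffusers import UNetSpatioTemporalConditionModel

local_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "checkpoints")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    os.path.join(local_root, "checkpoint-200000"), subfolder="unet"
)
unet.eval()
print(sum(p.numel() for p in unet.parameters()) / 1e6, "M parameters")
```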
setup/download_svd_weights.py ADDED
@@ -0,0 +1,13 @@
+ from huggingface_hub import snapshot_download
+
+ save_dir = "./svdh"
+
+ # 1. Download the full model repo (weights + config + assets)
+ local_dir = snapshot_download(
+     repo_id="stabilityai/stable-video-diffusion-img2vid",
+     revision="main",
+     local_dir=save_dir,
+     local_dir_use_symlinks=False  # ensures files are fully copied, not symlinked
+ )
+
+ print(f"Model downloaded to: {local_dir}")
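The resulting folder can then be handed to diffusers as a local model path. A minimal sketch, assuming a CUDA device and the fp16 variant included in the snapshot above:

```python
# Sketch: load the locally downloaded SVD weights instead of the Hub id.
import torch
from diffusers import StableVideoDiffusionPipeline

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "./svdh", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")  # drop the fp16/cuda settings for CPU-only use
```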
setup/environment.yaml ADDED
@@ -0,0 +1,225 @@
+ name: refocus
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - asttokens=3.0.0=pyhd8ed1ab_1
+   - bzip2=1.0.8=h5eee18b_6
+   - ca-certificates=2025.4.26=hbd8a1cb_0
+   - comm=0.2.2=pyhd8ed1ab_1
+   - debugpy=1.6.0=py310hd8f1fbe_0
+   - entrypoints=0.4=pyhd8ed1ab_1
+   - exceptiongroup=1.2.2=pyhd8ed1ab_1
+   - executing=2.2.0=pyhd8ed1ab_0
+   - ffmpeg=4.3.2=hca11adc_0
+   - freetype=2.10.4=h0708190_1
+   - gmp=6.2.1=h58526e2_0
+   - gnutls=3.6.13=h85f3911_1
+   - ipykernel=6.20.2=pyh210e3f2_0
+   - ipython=8.36.0=pyh907856f_0
+   - jedi=0.19.2=pyhd8ed1ab_1
+   - jupyter_client=7.3.4=pyhd8ed1ab_0
+   - jupyter_core=5.7.2=pyh31011fe_1
+   - lame=3.100=h7f98852_1001
+   - ld_impl_linux-64=2.40=h12ee557_0
+   - libevent=2.1.12=hdbd6064_1
+   - libffi=3.4.4=h6a678d5_1
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libpng=1.6.37=h21135ba_2
+   - libsodium=1.0.18=h36c2ea0_1
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libuuid=1.41.5=h5eee18b_0
+   - matplotlib-inline=0.1.7=pyhd8ed1ab_1
+   - ncurses=6.4=h6a678d5_0
+   - nest-asyncio=1.6.0=pyhd8ed1ab_1
+   - nettle=3.6=he412f7d_0
+   - openh264=2.1.1=h780b84a_0
+   - openssl=3.0.16=h5eee18b_0
+   - parso=0.8.4=pyhd8ed1ab_1
+   - pexpect=4.9.0=pyhd8ed1ab_1
+   - pickleshare=0.7.5=pyhd8ed1ab_1004
+   - pip=25.0=py310h06a4308_0
+   - platformdirs=4.3.7=pyh29332c3_0
+   - prompt-toolkit=3.0.51=pyha770c72_0
+   - ptyprocess=0.7.0=pyhd8ed1ab_1
+   - pure_eval=0.2.3=pyhd8ed1ab_1
+   - pygments=2.19.1=pyhd8ed1ab_0
+   - python=3.10.16=he870216_1
+   - python-dateutil=2.9.0.post0=pyhff2d567_1
+   - python_abi=3.10=2_cp310
+   - pyzmq=23.0.0=py310h330234f_0
+   - readline=8.2=h5eee18b_0
+   - setuptools=75.8.0=py310h06a4308_0
+   - six=1.17.0=pyhd8ed1ab_0
+   - sqlite=3.45.3=h5eee18b_0
+   - stack_data=0.6.3=pyhd8ed1ab_1
+   - tk=8.6.14=h39e8969_0
+   - tmux=3.3a=h5eee18b_1
+   - tornado=6.1=py310h5764c6d_3
+   - traitlets=5.14.3=pyhd8ed1ab_1
+   - typing_extensions=4.13.2=pyh29332c3_0
+   - wcwidth=0.2.13=pyhd8ed1ab_1
+   - wheel=0.45.1=py310h06a4308_0
+   - x264=1!161.3030=h7f98852_1
+   - xz=5.6.4=h5eee18b_1
+   - zeromq=4.3.4=h9c3ff4c_1
+   - zlib=1.2.13=h5eee18b_1
+   - pip:
+     - absl-py==2.2.0
+     - accelerate==1.5.2
+     - aiofiles==23.2.1
+     - aiohappyeyeballs==2.6.1
+     - aiohttp==3.12.14
+     - aiosignal==1.4.0
+     - annotated-types==0.7.0
+     - anyio==4.9.0
+     - async-timeout==5.0.1
+     - atomicwrites==1.4.1
+     - attrs==25.3.0
+     - beautifulsoup4==4.13.4
+     - certifi==2025.1.31
+     - cffi==1.17.1
+     - charset-normalizer==3.4.1
+     - click==8.1.8
+     - colour-science==0.4.6
+     - contourpy==1.3.1
+     - controlnet-aux==0.0.9
+     - cycler==0.12.1
+     - decorator==4.4.2
+     - decord==0.6.0
+     - denku==0.0.51
+     - diffusers==0.32.0
+     - distro==1.9.0
+     - docker-pycreds==0.4.0
+     - einops==0.8.1
+     - einops-exts==0.0.4
+     - fastapi==0.115.11
+     - ffmpeg-python==0.2.0
+     - ffmpy==0.5.0
+     - filelock==3.18.0
+     - flatbuffers==25.2.10
+     - fonttools==4.56.0
+     - frozenlist==1.7.0
+     - fsspec==2025.3.0
+     - future==1.0.0
+     - gdown==5.2.0
+     - gitdb==4.0.12
+     - gitpython==3.1.44
+     - gradio==5.22.0
+     - gradio-client==1.8.0
+     - groovy==0.1.2
+     - h11==0.14.0
+     - hf-transfer==0.1.9
+     - httpcore==1.0.7
+     - httpx==0.28.1
+     - huggingface-hub==0.29.3
+     - idna==3.10
+     - imageio==2.37.0
+     - imageio-ffmpeg==0.6.0
+     - importlib-metadata==8.6.1
+     - jax==0.5.3
+     - jaxlib==0.5.3
+     - jinja2==3.1.6
+     - jiter==0.9.0
+     - kiwisolver==1.4.8
+     - lazy-loader==0.4
+     - lightning==2.5.2
+     - lightning-utilities==0.14.3
+     - markdown-it-py==3.0.0
+     - markupsafe==3.0.2
+     - matplotlib==3.10.1
+     - mdurl==0.1.2
+     - mediapipe==0.10.21
+     - ml-dtypes==0.5.1
+     - moviepy==1.0.3
+     - mpmath==1.3.0
+     - multidict==6.6.3
+     - networkx==3.4.2
+     - numpy==1.26.0
+     - nvidia-cublas-cu12==12.4.5.8
+     - nvidia-cuda-cupti-cu12==12.4.127
+     - nvidia-cuda-nvrtc-cu12==12.4.127
+     - nvidia-cuda-runtime-cu12==12.4.127
+     - nvidia-cudnn-cu12==9.1.0.70
+     - nvidia-cufft-cu12==11.2.1.3
+     - nvidia-curand-cu12==10.3.5.147
+     - nvidia-cusolver-cu12==11.6.1.9
+     - nvidia-cusparse-cu12==12.3.1.170
+     - nvidia-cusparselt-cu12==0.6.2
+     - nvidia-ml-py==12.570.86
+     - nvidia-nccl-cu12==2.21.5
+     - nvidia-nvjitlink-cu12==12.4.127
+     - nvidia-nvtx-cu12==12.4.127
+     - nvitop==1.4.2
+     - openai==1.68.2
+     - opencv-contrib-python==4.11.0.86
+     - opencv-python==4.11.0.86
+     - opencv-python-headless==4.11.0.86
+     - opt-einsum==3.4.0
+     - orjson==3.10.15
+     - packaging==24.2
+     - pandas==2.2.3
+     - peft==0.15.0
+     - pillow==9.5.0
+     - proglog==0.1.10
+     - propcache==0.3.2
+     - protobuf==4.25.6
+     - psutil==5.9.8
+     - ptflops==0.7.4
+     - pycparser==2.22
+     - pydantic==2.10.6
+     - pydantic-core==2.27.2
+     - pydub==0.25.1
+     - pyparsing==3.2.1
+     - pysocks==1.7.1
+     - python-dotenv==1.0.1
+     - python-multipart==0.0.20
+     - pytorch-lightning==2.5.2
+     - pytz==2025.1
+     - pyyaml==6.0.2
+     - regex==2024.11.6
+     - requests==2.32.3
+     - rich==13.9.4
+     - ruff==0.11.2
+     - safehttpx==0.1.6
+     - safetensors==0.5.3
+     - scikit-image==0.24.0
+     - scikit-video==1.1.11
+     - scipy==1.15.2
+     - semantic-version==2.10.0
+     - sentencepiece==0.2.0
+     - sentry-sdk==2.24.0
+     - setproctitle==1.3.5
+     - shellingham==1.5.4
+     - smmap==5.0.2
+     - sniffio==1.3.1
+     - sounddevice==0.5.1
+     - soupsieve==2.7
+     - spaces==0.32.0
+     - spandrel==0.4.1
+     - starlette==0.46.1
+     - sympy==1.13.1
+     - tifffile==2025.3.13
+     - timm==0.6.7
+     - tokenizers==0.21.1
+     - tomlkit==0.13.2
+     - torch==2.6.0
+     - torch-fidelity==0.3.0
+     - torchmetrics==1.7.4
+     - torchvision==0.21.0
+     - tqdm==4.67.1
+     - transformers==4.50.0
+     - triton==3.2.0
+     - typer==0.15.2
+     - typing-extensions==4.12.2
+     - tzdata==2025.1
+     - urllib3==2.3.0
+     - uvicorn==0.34.0
+     - videoio==0.3.0
+     - wandb==0.19.8
+     - websockets==15.0.1
+     - yarl==1.20.1
+     - zipp==3.21.0
simplified_inference.py ADDED
@@ -0,0 +1,190 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Simplified inference script for the fine-tuned Stable Video Diffusion model."""
+
+ import math
+ import os
+ from torch.utils.data import Dataset
+ import accelerate
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from packaging import version
+ from tqdm.auto import tqdm
+ from transformers import CLIPVisionModelWithProjection
+ from simplified_validation import valid_net
+ from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
+ from diffusers.utils import check_min_version
+ import argparse
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+ check_min_version("0.24.0.dev0")
+
+ logger = get_logger(__name__, log_level="INFO")
+
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simplified SVD inference script")
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="/datasets/sai/focal-burst-learning/svd/training/configs/outside_photos.yaml",
+         help="Path to the config file.",
+     )
+     # seed is an optional int, defaulting to 0
+
+     parser.add_argument(
+         "--image_path",
+         type=str,
+         required=True,
+         help="Path to image input or directory containing input images",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=0,
+         help="A seed for reproducible inference.",
+     )
+
+     parser.add_argument(
+         "--learn2refocus_hf_repo_path",
+         type=str,
+         default="tedlasai/learn2refocus",
+         help="HF repo containing the weight files",
+     )
+
+     parser.add_argument(
+         "--pretrained_model_path",
+         type=str,
+         default="stabilityai/stable-video-diffusion-img2vid",
+         help="repo id or path for the pretrained Stable Video Diffusion model",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="outputs/simple_inference",
+         help="path to output",
+     )
+
+     parser.add_argument(
+         "--num_inference_steps",
+         type=int,
+         default=25,
+         help="number of DDPM steps",
+     )
+
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda",
+         help="inference device",
+     )
+
+
+     args = parser.parse_args()
+
+     return args
+
+
+
+ def find_scale(height, width):
+     max_pixels = 500000
+
+     # Start with no scaling
+     scale = 1.0
+
+     while True:
+         # Calculate the scaled dimensions (snapped down to multiples of 64)
+         scaled_height = math.floor((height * scale) / 64) * 64
+         scaled_width = math.floor((width * scale) / 64) * 64
+
+         # Check if the scaled dimensions meet the pixel constraint
+         if scaled_height * scaled_width <= max_pixels:
+             return scaled_height, scaled_width
+
+         # Reduce the scale slightly
+         scale -= 0.01
+
+ def convert_to_batch(image, input_focal_position, sample_frames=9):
+     scene, focal_stack_num = image, input_focal_position
+     from PIL import Image
+     with Image.open(scene) as img:
+
+         icc_profile = img.info.get("icc_profile")
+         if icc_profile is None:
+             icc_profile = "none"
+         original_pixels = torch.from_numpy(np.array(img)).float().permute(2,0,1)
+         original_pixels = original_pixels / 255
+         width, height = img.size
+         scaled_width, scaled_height = find_scale(width, height)
+
+         img_resized = img.resize((scaled_width, scaled_height))
+         img_tensor = torch.from_numpy(np.array(img_resized)).float()
+         img_normalized = img_tensor / 127.5 - 1
+         img_normalized = img_normalized.permute(2, 0, 1)
+
+         pixels = torch.zeros((1, sample_frames, 3, scaled_height, scaled_width))
+         pixels[0, focal_stack_num] = img_normalized
+
+     name = os.path.splitext(os.path.basename(scene))[0]
+     return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
+
+ def main():
+     args = parse_args()
+
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     if args.output_dir is not None:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # inference-only modules
+     image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+         args.pretrained_model_path, subfolder="image_encoder"
+     )
+     vae = AutoencoderKLTemporalDecoder.from_pretrained(
+         args.pretrained_model_path, subfolder="vae", variant="fp16"
+     )
+
+     weight_dtype = torch.float32
+     image_encoder.requires_grad_(False).to(device, dtype=weight_dtype)
+     vae.requires_grad_(False).to(device, dtype=weight_dtype)
+
+     # ---- load UNet from the checkpoint root (this reads unet/config.json + diffusion_pytorch_model.safetensors)
+     unet = UNetSpatioTemporalConditionModel.from_pretrained(
+         args.learn2refocus_hf_repo_path, subfolder="checkpoint-200000/unet"
+     ).to(device)
+
+     batch = convert_to_batch(args.image_path, input_focal_position=6)
+
+     unet.eval(); image_encoder.eval(); vae.eval()
+     with torch.no_grad():
+         valid_net(args, batch, unet, image_encoder, vae, 0, weight_dtype, device, num_inference_steps=args.num_inference_steps)
+
+ if __name__ == "__main__":
+     main()
+
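For clarity, `convert_to_batch` normalizes the uint8 image to `[-1, 1]` via `x / 127.5 - 1` and places that single conditioning frame at index `focal_stack_num` of an otherwise-zero 9-frame clip. A self-contained sketch of that layout on a dummy image (shapes are illustrative, not taken from the script):

```python
# Sketch: the batch layout produced by convert_to_batch for one conditioning frame.
import torch

sample_frames = 9
focal_stack_num = 6
h, w = 576, 1024  # illustrative resolution (multiples of 64)

img_uint8 = torch.randint(0, 256, (h, w, 3), dtype=torch.uint8)  # dummy image
img_normalized = img_uint8.float() / 127.5 - 1                   # map [0, 255] -> [-1, 1]
img_normalized = img_normalized.permute(2, 0, 1)                 # HWC -> CHW

pixels = torch.zeros((1, sample_frames, 3, h, w))
pixels[0, focal_stack_num] = img_normalized  # only the conditioning slot is non-zero

print(pixels.shape)  # torch.Size([1, 9, 3, 576, 1024])
print(pixels[0, focal_stack_num].min().item(), pixels[0, focal_stack_num].max().item())
```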
simplified_pipeline.py ADDED
@@ -0,0 +1,807 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ import random
18
+ from typing import Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.image_processor import PipelineImageInput
26
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
27
+ from diffusers.schedulers import EulerDiscreteScheduler
28
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ import torch.nn.functional as F
33
+ from tqdm import tqdm
34
+ from einops import rearrange
35
+
36
+
37
+ def tensor_to_vae_latent(t, vae, otype="sample"):
38
+ video_length = t.shape[1]
39
+
40
+ t = rearrange(t, "b f c h w -> (b f) c h w")
41
+ if otype == "sample":
42
+ latents = vae.encode(t).latent_dist.sample()
43
+ else:
44
+ latents = vae.encode(t).latent_dist.mode()
45
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
46
+ latents = latents * vae.config.scaling_factor
47
+
48
+ return latents
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> from diffusers import StableVideoDiffusionPipeline
57
+ >>> from diffusers.utils import load_image, export_to_video
58
+
59
+ >>> pipe = StableVideoDiffusionPipeline.from_pretrained(
60
+ ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
61
+ ... )
62
+ >>> pipe.to("cuda")
63
+
64
+ >>> image = load_image(
65
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
66
+ ... )
67
+ >>> image = image.resize((1024, 576))
68
+
69
+ >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
70
+ >>> export_to_video(frames, "generated.mp4", fps=7)
71
+ ```
72
+ """
73
+
74
+
75
+ def _append_dims(x, target_dims):
76
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
77
+ dims_to_append = target_dims - x.ndim
78
+ if dims_to_append < 0:
79
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
80
+ return x[(...,) + (None,) * dims_to_append]
81
+
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
84
+ def retrieve_timesteps(
85
+ scheduler,
86
+ num_inference_steps: Optional[int] = None,
87
+ device: Optional[Union[str, torch.device]] = None,
88
+ timesteps: Optional[List[int]] = None,
89
+ sigmas: Optional[List[float]] = None,
90
+ **kwargs,
91
+ ):
92
+ r"""
93
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
94
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
95
+
96
+ Args:
97
+ scheduler (`SchedulerMixin`):
98
+ The scheduler to get timesteps from.
99
+ num_inference_steps (`int`):
100
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
101
+ must be `None`.
102
+ device (`str` or `torch.device`, *optional*):
103
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
104
+ timesteps (`List[int]`, *optional*):
105
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
106
+ `num_inference_steps` and `sigmas` must be `None`.
107
+ sigmas (`List[float]`, *optional*):
108
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
109
+ `num_inference_steps` and `timesteps` must be `None`.
110
+
111
+ Returns:
112
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
113
+ second element is the number of inference steps.
114
+ """
115
+ if timesteps is not None and sigmas is not None:
116
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ if not accepts_timesteps:
120
+ raise ValueError(
121
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ elif sigmas is not None:
128
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accept_sigmas:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ else:
138
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ return timesteps, num_inference_steps
141
+
142
+
143
+ @dataclass
144
+ class StableVideoDiffusionPipelineOutput(BaseOutput):
145
+ r"""
146
+ Output class for Stable Video Diffusion pipeline.
147
+
148
+ Args:
149
+ frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
150
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
151
+ num_frames, height, width, num_channels)`.
152
+ """
153
+
154
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
155
+
156
+
157
+ class StableVideoDiffusionPipeline(DiffusionPipeline):
158
+ r"""
159
+ Pipeline to generate video from an input image using Stable Video Diffusion.
160
+
161
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
162
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
163
+
164
+ Args:
165
+ vae ([`AutoencoderKLTemporalDecoder`]):
166
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
167
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
168
+ Frozen CLIP image-encoder
169
+ ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
170
+ unet ([`UNetSpatioTemporalConditionModel`]):
171
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
172
+ scheduler ([`EulerDiscreteScheduler`]):
173
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
174
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
175
+ A `CLIPImageProcessor` to extract features from generated images.
176
+ """
177
+
178
+ model_cpu_offload_seq = "image_encoder->unet->vae"
179
+ _callback_tensor_inputs = ["latents"]
180
+
181
+ def __init__(
182
+ self,
183
+ vae: AutoencoderKLTemporalDecoder,
184
+ image_encoder: CLIPVisionModelWithProjection,
185
+ unet: UNetSpatioTemporalConditionModel,
186
+ scheduler: EulerDiscreteScheduler,
187
+ feature_extractor: CLIPImageProcessor,
188
+ ):
189
+ super().__init__()
190
+
191
+ self.register_modules(
192
+ vae=vae,
193
+ image_encoder=image_encoder,
194
+ unet=unet,
195
+ scheduler=scheduler,
196
+ feature_extractor=feature_extractor,
197
+ )
198
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
199
+ self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
200
+
201
+
202
+
203
+ def _encode_image(
204
+ self,
205
+ image: PipelineImageInput,
206
+ device: Union[str, torch.device],
207
+ num_videos_per_prompt: int,
208
+ do_classifier_free_guidance: bool,
209
+ ) -> torch.Tensor:
210
+ dtype = next(self.image_encoder.parameters()).dtype
211
+
212
+ if not isinstance(image, torch.Tensor):
213
+ image = self.video_processor.pil_to_numpy(image)
214
+ image = self.video_processor.numpy_to_pt(image)
215
+
216
+ # We normalize the image before resizing to match with the original implementation.
217
+ # Then we unnormalize it after resizing.
218
+ image = image * 2.0 - 1.0
219
+ image = _resize_with_antialiasing(image, (224, 224))
220
+ image = (image + 1.0) / 2.0
221
+
222
+
223
+ # Normalize the image with for CLIP input
224
+ image = self.feature_extractor(
225
+ images=image,
226
+ do_normalize=True,
227
+ do_center_crop=False,
228
+ do_resize=False,
229
+ do_rescale=False,
230
+ return_tensors="pt",
231
+ ).pixel_values
232
+
233
+ image = image.to(device=device, dtype=dtype)
234
+ image_embeddings = self.image_encoder(image).image_embeds
235
+ image_embeddings = image_embeddings.unsqueeze(1)
236
+
237
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
238
+ bs_embed, seq_len, _ = image_embeddings.shape
239
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
240
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
241
+
242
+ if do_classifier_free_guidance:
243
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
244
+
245
+ # For classifier free guidance, we need to do two forward passes.
246
+ # Here we concatenate the unconditional and text embeddings into a single batch
247
+ # to avoid doing two forward passes
248
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
249
+
250
+ return image_embeddings
251
+
252
+ def _encode_vae_image(
253
+ self,
254
+ image: torch.Tensor,
255
+ device: Union[str, torch.device],
256
+ num_videos_per_prompt: int,
257
+ do_classifier_free_guidance: bool,
258
+ ):
259
+ image = image.to(device=device)
260
+ image_latents = self.vae.encode(image).latent_dist.mode()
261
+
262
+ # duplicate image_latents for each generation per prompt, using mps friendly method
263
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
264
+
265
+ if do_classifier_free_guidance:
266
+ negative_image_latents = torch.zeros_like(image_latents)
267
+
268
+ # For classifier free guidance, we need to do two forward passes.
269
+ # Here we concatenate the unconditional and text embeddings into a single batch
270
+ # to avoid doing two forward passes
271
+ image_latents = torch.cat([negative_image_latents, image_latents])
272
+
273
+ return image_latents
274
+
275
+ def _get_add_time_ids(
276
+ self,
277
+ fps: int,
278
+ motion_bucket_id: int,
279
+ noise_aug_strength: float,
280
+ dtype: torch.dtype,
281
+ batch_size: int,
282
+ num_videos_per_prompt: int,
283
+ do_classifier_free_guidance: bool,
284
+ ):
285
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
286
+
287
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
288
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
289
+
290
+ if expected_add_embed_dim != passed_add_embed_dim:
291
+ raise ValueError(
292
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
293
+ )
294
+
295
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
296
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
297
+
298
+ if do_classifier_free_guidance:
299
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
300
+
301
+ return add_time_ids
302
+
303
+ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
304
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
305
+ latents = latents.flatten(0, 1)
306
+
307
+ latents = 1 / self.vae.config.scaling_factor * latents
308
+
309
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
310
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
311
+
312
+ # decode decode_chunk_size frames at a time to avoid OOM
313
+ frames = []
314
+ for i in range(0, latents.shape[0], decode_chunk_size):
315
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
316
+ decode_kwargs = {}
317
+ if accepts_num_frames:
318
+ # we only pass num_frames_in if it's expected
319
+ decode_kwargs["num_frames"] = num_frames_in
320
+
321
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
322
+ frames.append(frame)
323
+ frames = torch.cat(frames, dim=0)
324
+
325
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
326
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
327
+
328
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
329
+ frames = frames.float()
330
+ return frames
331
+
332
+ def check_inputs(self, image, height, width):
333
+ if (
334
+ not isinstance(image, torch.Tensor)
335
+ and not isinstance(image, PIL.Image.Image)
336
+ and not isinstance(image, list)
337
+ ):
338
+ raise ValueError(
339
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
340
+ f" {type(image)}"
341
+ )
342
+
343
+ if height % 8 != 0 or width % 8 != 0:
344
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
345
+
346
+ def prepare_latents(
347
+ self,
348
+ batch_size: int,
349
+ num_frames: int,
350
+ num_channels_latents: int,
351
+ height: int,
352
+ width: int,
353
+ dtype: torch.dtype,
354
+ device: Union[str, torch.device],
355
+ generator: torch.Generator,
356
+ latents: Optional[torch.Tensor] = None,
357
+ ):
358
+ shape = (
359
+ batch_size,
360
+ num_frames,
361
+ num_channels_latents // 2,
362
+ height // self.vae_scale_factor,
363
+ width // self.vae_scale_factor,
364
+ )
365
+ if isinstance(generator, list) and len(generator) != batch_size:
366
+ raise ValueError(
367
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
368
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
369
+ )
370
+
371
+ if latents is None:
372
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
373
+ else:
374
+ latents = latents.to(device)
375
+
376
+ # scale the initial noise by the standard deviation required by the scheduler
377
+ latents = latents * self.scheduler.init_noise_sigma
378
+ return latents
379
+
380
+ @property
381
+ def guidance_scale(self):
382
+ return self._guidance_scale
383
+
384
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
385
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
386
+ # corresponds to doing no classifier free guidance.
387
+ @property
388
+ def do_classifier_free_guidance(self):
389
+ if isinstance(self.guidance_scale, (int, float)):
390
+ return self.guidance_scale > 0
391
+ return self.guidance_scale.max() > 0
392
+
393
+ @property
394
+ def num_timesteps(self):
395
+ return self._num_timesteps
396
+
397
+ #@torch.no_grad()
398
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
399
+ def __call__(
400
+ self,
401
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
402
+ height: int = 576,
403
+ width: int = 1024,
404
+ num_frames: Optional[int] = None,
405
+ num_inference_steps: int = 25,
406
+ sigmas: Optional[List[float]] = None,
407
+ min_guidance_scale: float = 1.0,
408
+ max_guidance_scale: float = 3.0,
409
+ fps: int = 7,
410
+ motion_bucket_id: int = 127,
411
+ noise_aug_strength: float = 0.02,
412
+ decode_chunk_size: Optional[int] = None,
413
+ num_videos_per_prompt: Optional[int] = 1,
414
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
415
+ latents: Optional[torch.Tensor] = None,
416
+ output_type: Optional[str] = "pil",
417
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
418
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
419
+ return_dict: bool = True,
420
+ focal_stack_num: int = None,
421
+ ):
422
+ r"""
423
+ The call function to the pipeline for generation.
424
+
425
+ Args:
426
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
427
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
428
+ 1]`.
429
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
430
+ The height in pixels of the generated image.
431
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
432
+ The width in pixels of the generated image.
433
+ num_frames (`int`, *optional*):
434
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
435
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
436
+ num_inference_steps (`int`, *optional*, defaults to 25):
437
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
438
+ expense of slower inference. This parameter is modulated by `strength`.
439
+ sigmas (`List[float]`, *optional*):
440
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
441
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
442
+ will be used.
443
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
444
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
445
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
446
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
447
+ fps (`int`, *optional*, defaults to 7):
448
+ Frames per second. The rate at which the generated images shall be exported to a video after
449
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
450
+ motion_bucket_id (`int`, *optional*, defaults to 127):
451
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
452
+ will be in the video.
453
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
454
+ The amount of noise added to the init image, the higher it is the less the video will look like the
455
+ init image. Increase it for more motion.
456
+ decode_chunk_size (`int`, *optional*):
457
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
458
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
459
+ For lower memory usage, reduce `decode_chunk_size`.
460
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
461
+ The number of videos to generate per prompt.
462
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
463
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
464
+ generation deterministic.
465
+ latents (`torch.Tensor`, *optional*):
466
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
467
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
468
+ tensor is generated by sampling using the supplied random `generator`.
469
+ output_type (`str`, *optional*, defaults to `"pil"`):
470
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
471
+ callback_on_step_end (`Callable`, *optional*):
472
+ A function that is called at the end of each denoising step during inference. The function is called
473
+ with the following arguments:
474
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
475
+ `callback_kwargs` will include a list of all tensors as specified by
476
+ `callback_on_step_end_tensor_inputs`.
477
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
478
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
479
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
480
+ `._callback_tensor_inputs` attribute of your pipeline class.
481
+ return_dict (`bool`, *optional*, defaults to `True`):
482
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
483
+ plain tuple.
484
+
485
+ Examples:
486
+
487
+ Returns:
488
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
489
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
490
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
491
+ returned.
492
+ """
493
+
494
+ with torch.no_grad():
495
+
496
+ # 0. Default height and width to unet
497
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
498
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
499
+
500
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
501
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
502
+
503
+ # 1. Check inputs. Raise error if not correct
504
+ self.check_inputs(image, height, width)
505
+
506
+ # 2. Define call parameters
507
+ if isinstance(image, PIL.Image.Image):
508
+ batch_size = 1
509
+ elif isinstance(image, list):
510
+ batch_size = len(image)
511
+ else:
512
+ batch_size = image.shape[0]
513
+ device = self._execution_device
514
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
515
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
516
+ # corresponds to doing no classifier free guidance.
517
+ self._guidance_scale = max_guidance_scale
518
+
519
+
520
+
521
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
522
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
523
+ fps = fps - 1
524
+
525
+ # 4. Encode input image using VAE
526
+ # first_image = image[0, 0:1]
527
+ # first_image = self.video_processor.preprocess(first_image*0.5+0.5, height=height, width=width).to(device)
528
+ # noise = randn_tensor(first_image.shape, generator=generator, device=device, dtype=image.dtype)
529
+ # first_image = first_image + noise_aug_strength * noise #you add this noise to have a version of the image that the vae can denoise
530
+
531
+ # first_image = self.video_processor.preprocess(first_image*0.5+0.5, height=height, width=width).to(device)
532
+
533
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
534
+ if needs_upcasting:
535
+ self.vae.to(dtype=torch.float32)
536
+
537
+
538
+ image_latents = tensor_to_vae_latent(image, self.vae, otype="mode")/self.vae.config.scaling_factor
539
+ #noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=image.dtype)
540
+ #image_latents = image_latents + noise_aug_strength * noise #you add this noise to have a version of the image that the vae can denoise
541
+
542
+ # old_image_latents = self._encode_vae_image(
543
+ # first_image,
544
+ # device=device,
545
+ # num_videos_per_prompt=num_videos_per_prompt,
546
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
547
+ # )
548
+
549
+ if self.do_classifier_free_guidance:
550
+ negative_image_latents = torch.zeros_like(image_latents)
551
+
552
+ # For classifier free guidance, we need to do two forward passes.
553
+ # Here we concatenate the unconditional and text embeddings into a single batch
554
+ # to avoid doing two forward passes
555
+ image_latents = torch.cat([negative_image_latents, image_latents])
556
+
557
+ image_latents = image_latents.to(torch.float32)
558
+
559
+ # cast back to fp16 if needed
560
+ if needs_upcasting:
561
+ self.vae.to(dtype=torch.float16)
562
+
563
+ # Repeat the image latents for each frame so we can concatenate them with the noise
564
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
565
+ #image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
566
+ mask = torch.zeros_like(image_latents)
567
+
568
+ if focal_stack_num is not None:
569
+ frame_idx = focal_stack_num
570
+ mask[:, frame_idx] = 1
571
+
572
+ original_image_latents = image_latents.clone()
573
+ image_latents = image_latents * mask
574
+
575
+ mask = mask == 1 #mask is a boolean tensor
576
+
577
+
578
+ clip_image = image[0, frame_idx: frame_idx+1]
579
+ resized_clip_image = _resize_with_antialiasing(clip_image, (224, 224))
580
+ image_embeddings = self._encode_image(resized_clip_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
581
+
582
+ if motion_bucket_id is None: #this hits for ablation_time at validation time
583
+ motion_bucket_id = 0
584
+
585
+
586
+ # 5. Get Added Time IDs
587
+ added_time_ids = self._get_add_time_ids(
588
+ fps,
589
+ motion_bucket_id,
590
+ noise_aug_strength,
591
+ image_embeddings.dtype,
592
+ batch_size,
593
+ num_videos_per_prompt,
594
+ self.do_classifier_free_guidance,
595
+ )
596
+ added_time_ids = added_time_ids.to(device)
597
+
598
+ # 6. Prepare timesteps
599
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
600
+
601
+ # 7. Prepare latent variables
602
+ num_channels_latents = self.unet.config.in_channels
603
+ latents = self.prepare_latents(
604
+ batch_size * num_videos_per_prompt,
605
+ num_frames,
606
+ num_channels_latents,
607
+ height,
608
+ width,
609
+ image_embeddings.dtype,
610
+ device,
611
+ generator,
612
+ latents,
613
+ )
614
+
615
+ # 8. Prepare guidance scale
616
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
617
+ guidance_scale = guidance_scale.to(device, latents.dtype)
618
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
619
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
620
+
621
+ self._guidance_scale = guidance_scale
622
+
623
+ # 9. Denoising loop
624
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
625
+ self._num_timesteps = len(timesteps)
626
+
627
+
628
+ alphas_cumprod = 1 / (1 + self.scheduler.sigmas**2)
629
+ alphas = alphas_cumprod / torch.cat((torch.tensor([1.0]), alphas_cumprod[:-1]))
630
+
631
+
632
+ progress_bar = tqdm(range(num_inference_steps))
633
+ for i, t in enumerate(timesteps):
634
+ # expand the latents if we are doing classifier free guidance - this is because we have the unconditional and the conditional portion
635
+ #this is concatenation along the batch dimension
636
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
637
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
638
+
639
+ # Concatenate image_latents over channels dimension
640
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
641
+
642
+ # predict the noise residual
643
+ with torch.no_grad():
644
+ noise_pred_uncond = self.unet(
645
+ latent_model_input[0:1],
646
+ t,
647
+ encoder_hidden_states=image_embeddings[0:1],
648
+ added_time_ids=added_time_ids[0:1],
649
+ return_dict=False,
650
+ )[0]
651
+
652
+ noise_pred_cond = self.unet(
653
+ latent_model_input[1:2],
654
+ t,
655
+ encoder_hidden_states=image_embeddings[1:2],
656
+ added_time_ids=added_time_ids[1:2],
657
+ return_dict=False,
658
+ )[0]
659
+
660
+ with torch.no_grad():
661
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred_cond])
662
+
663
+ # perform guidance
664
+ if self.do_classifier_free_guidance:
665
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
666
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
667
+ # compute the previous noisy sample x_t -> x_t-1
668
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
669
+
670
+
671
+ with torch.no_grad():
672
+ if callback_on_step_end is not None:
673
+ callback_kwargs = {}
674
+ for k in callback_on_step_end_tensor_inputs:
675
+ callback_kwargs[k] = locals()[k]
676
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
677
+ latents = callback_outputs.pop("latents", latents)
678
+
679
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
680
+ progress_bar.update()
681
+
682
+
683
+ with torch.no_grad():
684
+ if not output_type == "latent":
685
+ # cast back to fp16 if needed
686
+ if needs_upcasting:
687
+ self.vae.to(dtype=torch.float16)
688
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
689
+ gt = self.decode_latents(original_image_latents[1:2]*self.vae.config.scaling_factor, num_frames, decode_chunk_size)
690
+ else:
691
+ frames = latents
692
+
693
+ self.maybe_free_model_hooks()
694
+
695
+ if not return_dict:
696
+ return frames
697
+
698
+ return StableVideoDiffusionPipelineOutput(frames=frames), gt
699
+
700
+
701
+ # resizing utils
702
+ # TODO: clean up later
703
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
704
+ h, w = input.shape[-2:]
705
+ factors = (h / size[0], w / size[1])
706
+
707
+ # First, we have to determine sigma
708
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
709
+ sigmas = (
710
+ max((factors[0] - 1.0) / 2.0, 0.001),
711
+ max((factors[1] - 1.0) / 2.0, 0.001),
712
+ )
713
+
714
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
715
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
716
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
717
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
718
+
719
+ # Make sure it is odd
720
+ if (ks[0] % 2) == 0:
721
+ ks = ks[0] + 1, ks[1]
722
+
723
+ if (ks[1] % 2) == 0:
724
+ ks = ks[0], ks[1] + 1
725
+
726
+ input = _gaussian_blur2d(input, ks, sigmas)
727
+
728
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
729
+ return output
730
+
731
+
732
+ def _compute_padding(kernel_size):
733
+ """Compute padding tuple."""
734
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
735
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
736
+ if len(kernel_size) < 2:
737
+ raise AssertionError(kernel_size)
738
+ computed = [k - 1 for k in kernel_size]
739
+
740
+ # for even kernels we need to do asymmetric padding :(
741
+ out_padding = 2 * len(kernel_size) * [0]
742
+
743
+ for i in range(len(kernel_size)):
744
+ computed_tmp = computed[-(i + 1)]
745
+
746
+ pad_front = computed_tmp // 2
747
+ pad_rear = computed_tmp - pad_front
748
+
749
+ out_padding[2 * i + 0] = pad_front
750
+ out_padding[2 * i + 1] = pad_rear
751
+
752
+ return out_padding
753
+
754
+
755
+ def _filter2d(input, kernel):
756
+ # prepare kernel
757
+ b, c, h, w = input.shape
758
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
759
+
760
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
761
+
762
+ height, width = tmp_kernel.shape[-2:]
763
+
764
+ padding_shape: List[int] = _compute_padding([height, width])
765
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
766
+
767
+ # kernel and input tensor reshape to align element-wise or batch-wise params
768
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
769
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
770
+
771
+ # convolve the tensor with the kernel.
772
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
773
+
774
+ out = output.view(b, c, h, w)
775
+ return out
776
+
777
+
778
+ def _gaussian(window_size: int, sigma):
779
+ if isinstance(sigma, float):
780
+ sigma = torch.tensor([[sigma]])
781
+
782
+ batch_size = sigma.shape[0]
783
+
784
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
785
+
786
+ if window_size % 2 == 0:
787
+ x = x + 0.5
788
+
789
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
790
+
791
+ return gauss / gauss.sum(-1, keepdim=True)
792
+
793
+
794
+ def _gaussian_blur2d(input, kernel_size, sigma):
795
+ if isinstance(sigma, tuple):
796
+ sigma = torch.tensor([sigma], dtype=input.dtype)
797
+ else:
798
+ sigma = sigma.to(dtype=input.dtype)
799
+
800
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
801
+ bs = sigma.shape[0]
802
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
803
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
804
+ out_x = _filter2d(input, kernel_x[..., None, :])
805
+ out = _filter2d(out_x, kernel_y[..., None])
806
+
807
+ return out
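
The denoising loop in the pipeline above combines the unconditional and conditional UNet predictions with a per-frame guidance scale built from `min_guidance_scale`/`max_guidance_scale` and broadcast via `_append_dims`. Below is a minimal, self-contained sketch of just that combination; all shapes and values are illustrative assumptions, not values from this commit.

```py
import torch

# illustrative shapes: [batch, frames, channels, height, width]
batch, num_frames = 1, 9
noise_pred_uncond = torch.randn(batch, num_frames, 4, 90, 160)
noise_pred_cond = torch.randn(batch, num_frames, 4, 90, 160)

# per-frame guidance scale, shape [batch, num_frames]
guidance_scale = torch.linspace(1.0, 3.0, num_frames).unsqueeze(0).repeat(batch, 1)

# append trailing singleton dims so it broadcasts against the latents (what _append_dims does)
while guidance_scale.ndim < noise_pred_uncond.ndim:
    guidance_scale = guidance_scale[..., None]

# classifier-free guidance: uncond + scale * (cond - uncond)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 9, 4, 90, 160])
```
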
simplified_validation.py ADDED
@@ -0,0 +1,108 @@
1
+ from simplified_pipeline import StableVideoDiffusionPipeline
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import videoio
6
+ import matplotlib.image
7
+ from PIL import Image
8
+
9
+
10
+
11
+ def valid_net(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
12
+
13
+ # The models need unwrapping for compatibility with distributed training mode.
14
+
15
+ pipeline = StableVideoDiffusionPipeline.from_pretrained(
16
+ args.pretrained_model_path,
17
+ unet=unet,
18
+ image_encoder=image_encoder,
19
+ vae=vae,
20
+ torch_dtype=weight_dtype,
21
+ )
22
+
23
+ pipeline.set_progress_bar_config(disable=True)
24
+
25
+ # run inference
26
+ val_save_dir = os.path.join(
27
+ args.output_dir, "validation_images")
28
+
29
+ print("Validation images will be saved to ", val_save_dir)
30
+
31
+ os.makedirs(val_save_dir, exist_ok=True)
32
+
33
+
34
+ num_frames = 9
35
+ unet.eval()
36
+
37
+ # torch.no_grad() disables gradient tracking here, which is what keeps validation memory in check
38
+ with torch.no_grad():
39
+ torch.cuda.empty_cache()
40
+
41
+ pixel_values = batch["pixel_values"].to(device)
42
+ original_pixel_values = batch['original_pixel_values'].to(device)
43
+ focal_stack_num = batch["focal_stack_num"]
44
+
45
+ svd_output, gt_frames = pipeline(
46
+ pixel_values,
47
+ height=pixel_values.shape[3],
48
+ width=pixel_values.shape[4],
49
+ num_frames=num_frames,
50
+ decode_chunk_size=8,
51
+ motion_bucket_id=0,
52
+ min_guidance_scale=1.5,
53
+ max_guidance_scale=1.5,
54
+ fps=7,
55
+ noise_aug_strength=0,
56
+ focal_stack_num = focal_stack_num,
57
+ num_inference_steps=args.num_inference_steps,
58
+ )
59
+ video_frames = svd_output.frames[0]
60
+ gt_frames = gt_frames[0]
61
+
62
+
63
+ with torch.no_grad():
64
+
65
+ if len(original_pixel_values.shape) == 5:
66
+ pixel_values = original_pixel_values[0] #assuming batch size is 1
67
+ else:
68
+ pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
69
+ pixel_values_normalized = pixel_values*0.5 + 0.5
70
+ pixel_values_normalized = torch.clamp(pixel_values_normalized,0,1)
71
+
72
+
73
+
74
+
75
+ video_frames_normalized = video_frames*0.5 + 0.5
76
+ video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
77
+ video_frames_normalized = video_frames_normalized.permute(1,0,2,3)
78
+
79
+
80
+ gt_frames = torch.clamp(gt_frames,0,1)
81
+ gt_frames = gt_frames.permute(1,0,2,3)
82
+
83
+ # Resize all outputs to even height/width
84
+ video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
85
+ gt_frames = torch.nn.functional.interpolate(gt_frames, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
86
+ pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
87
+
88
+ os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
89
+ videoio.videosave(os.path.join(
90
+ val_save_dir,
91
+ f"position_{focal_stack_num}/videos/{batch['name']}.mp4",
92
+ ), video_frames_normalized.permute(0,2,3,1).cpu().numpy(), fps=5)
93
+
94
+ #save images
95
+ os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
96
+ for i in range(num_frames):
97
+ #use Pillow to save images
98
+ img = Image.fromarray((video_frames_normalized[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
99
+ # attach the ICC profile to the image if one was provided
100
+ if batch['icc_profile'] != "none":
101
+ img.info['icc_profile'] = batch['icc_profile']
102
+ path = os.path.join(val_save_dir, f"position_{focal_stack_num}/images/{batch['name']}_frame_{i}.png")
103
+ print("Saving image to ", path)
104
+ img.save(path)
105
+ del video_frames
106
+
107
+
108
+
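
`simplified_validation.py` maps the decoded frames from [-1, 1] back to [0, 1], reorders them to [frames, channels, H, W], and resizes to even spatial dimensions before writing the video and PNGs. The snippet below is a small sketch of just that post-processing with a random tensor standing in for the decoded output; the real script resizes to the original input's size, while here the frames' own size is reused for brevity.

```py
import torch

# stand-in for decoded frames in [-1, 1]: [channels, frames, height, width]
video_frames = torch.randn(3, 9, 725, 1281).clamp(-1, 1)

frames = video_frames * 0.5 + 0.5                 # [-1, 1] -> [0, 1]
frames = frames.clamp(0, 1).permute(1, 0, 2, 3)   # -> [frames, channels, H, W]

# round spatial dimensions down to even numbers, as the validation script does
h, w = (frames.shape[2] // 2) * 2, (frames.shape[3] // 2) * 2
frames = torch.nn.functional.interpolate(frames, (h, w), mode="bilinear")
print(frames.shape)  # torch.Size([9, 3, 724, 1280])
```
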
splits/test_scenes.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b7d6ac77b97cf4b5fa62ffa13df88fc6dec2dfe4d5fbc981b79373c4766b86a
3
+ size 4936
splits/train_scenes.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec9f60c6cee001b10f0f0928f24d4fafc54ec6c3d9ed1e34069b3c0da0e8e570
3
+ size 44238
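
The two entries above are Git LFS pointers, so the pickle contents are not visible in this diff. Assuming they hold the train/test scene lists consumed by the dataset code, loading them is a one-liner per split; this is a sketch, not part of the commit, and the actual structure of the pickles is an assumption.

```py
import pickle

# assumed: each pickle stores the scene identifiers for its split
with open("splits/train_scenes.pkl", "rb") as f:
    train_scenes = pickle.load(f)
with open("splits/test_scenes.pkl", "rb") as f:
    test_scenes = pickle.load(f)

print(len(train_scenes), len(test_scenes))
```
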
training/configs/accelerator_config.yaml ADDED
@@ -0,0 +1,25 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ main_process_port: 29501
3
+ debug: false
4
+ deepspeed_config:
5
+ gradient_accumulation_steps: 1
6
+ gradient_clipping: 1.0
7
+ offload_optimizer_device: none
8
+ offload_param_device: none
9
+ zero3_init_flag: false
10
+ zero_stage: 2
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ enable_cpu_affinity: false
14
+ machine_rank: 0
15
+ main_training_function: main
16
+ dynamo_backend: 'no'
17
+ mixed_precision: 'no'
18
+ num_machines: 1
19
+ num_processes: 4
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
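
`accelerator_config.yaml` runs 4 processes with DeepSpeed ZeRO stage 2 and a gradient-accumulation factor of 1. Combined with the per-GPU batch size in the training configs below, the effective batch size works out as in this small sketch; the paths follow the repo layout in this diff and PyYAML is assumed to be installed.

```py
import yaml  # PyYAML, assumed available in this environment

with open("training/configs/accelerator_config.yaml") as f:
    accel = yaml.safe_load(f)
with open("training/configs/focal_stacks_train.yaml") as f:
    train = yaml.safe_load(f)

effective_batch_size = (
    train["per_gpu_batch_size"]                                  # 1
    * accel["num_processes"]                                     # 4
    * accel["deepspeed_config"]["gradient_accumulation_steps"]   # 1
)
print(effective_batch_size)  # 4
```
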
training/configs/focal_stacks_test.yaml ADDED
@@ -0,0 +1,47 @@
1
+ data_folder: "/datasets/sai/scenes_merged"
2
+ splits_dir: "./splits" #all split.pkl files are stored here
3
+ pretrained_model_name_or_path: "./svdh"
4
+ load_from_checkpoint: "./checkpoints/checkpoint-200000"
5
+ output_dir: "./outputs/focal_stacks_test"
6
+ wandb_project: "RefocusingSVD"
7
+ run_name: "focal_stacks_test"
8
+ test: true
9
+ revision: null
10
+ num_frames: 9
11
+ num_validation_images: 1
12
+ validation_steps: 1000
13
+ photos: false
14
+ conditioning: "random"
15
+ seed: 0
16
+ per_gpu_batch_size: 1
17
+ num_train_epochs: 600
18
+ max_train_steps: null
19
+ gradient_accumulation_steps: 1
20
+ gradient_checkpointing: false
21
+ learning_rate: 0.00001
22
+ reconstruction_guidance: 0
23
+ scale_lr: true
24
+ lr_scheduler: "constant"
25
+ lr_warmup_steps: 0
26
+ conditioning_dropout_prob: 0.1
27
+ use_8bit_adam: false
28
+ allow_tf32: false
29
+ use_ema: false
30
+ non_ema_revision: null
31
+ num_workers: 32
32
+ adam_beta1: 0.9
33
+ adam_beta2: 0.999
34
+ adam_weight_decay: 0.01
35
+ adam_epsilon: 0.0000001
36
+ max_grad_norm: 1.0
37
+ push_to_hub: false
38
+ hub_token: null
39
+ hub_model_id: null
40
+ logging_dir: "logs"
41
+ mixed_precision: null
42
+ report_to: "wandb"
43
+ local_rank: -1
44
+ checkpointing_steps: 500
45
+ checkpoints_total_limit: 2
46
+ enable_xformers_memory_efficient_attention: false
47
+ pretrain_unet: null
training/configs/focal_stacks_train.yaml ADDED
@@ -0,0 +1,47 @@
1
+ data_folder: "/datasets/sai/scenes_merged"
2
+ splits_dir: "./splits" #all split.pkl files are stored here
3
+ pretrained_model_name_or_path: "./svdh"
4
+ load_from_checkpoint: null
5
+ output_dir: "./outputs/focal_stacks_train"
6
+ wandb_project: "RefocusingSVD"
7
+ run_name: "focal_stacks_train"
8
+ test: false
9
+ revision: null
10
+ num_frames: 9
11
+ num_validation_images: 1
12
+ validation_steps: 1000
13
+ photos: false
14
+ conditioning: "random"
15
+ seed: 0
16
+ per_gpu_batch_size: 1
17
+ num_train_epochs: 600
18
+ max_train_steps: null
19
+ gradient_accumulation_steps: 1
20
+ gradient_checkpointing: false
21
+ learning_rate: 0.00001
22
+ reconstruction_guidance: 0
23
+ scale_lr: true
24
+ lr_scheduler: "constant"
25
+ lr_warmup_steps: 0
26
+ conditioning_dropout_prob: 0.1
27
+ use_8bit_adam: false
28
+ allow_tf32: false
29
+ use_ema: false
30
+ non_ema_revision: null
31
+ num_workers: 32
32
+ adam_beta1: 0.9
33
+ adam_beta2: 0.999
34
+ adam_weight_decay: 0.01
35
+ adam_epsilon: 0.0000001
36
+ max_grad_norm: 1.0
37
+ push_to_hub: false
38
+ hub_token: null
39
+ hub_model_id: null
40
+ logging_dir: "logs"
41
+ mixed_precision: null
42
+ report_to: "wandb"
43
+ local_rank: -1
44
+ checkpointing_steps: 500
45
+ checkpoints_total_limit: 2
46
+ enable_xformers_memory_efficient_attention: false
47
+ pretrain_unet: null
training/configs/outside_photos.yaml ADDED
@@ -0,0 +1,46 @@
1
+ photos: true # Use outside photos
2
+ data_folder: "./photos"
3
+ pretrained_model_name_or_path: "./svdh"
4
+ load_from_checkpoint: "./checkpoints/checkpoint-200000"
5
+ output_dir: "./outputs/outside_photos"
6
+ wandb_project: "RefocusingSVD"
7
+ run_name: "outside_photos"
8
+ test: true
9
+ revision: null
10
+ num_frames: 9
11
+ num_validation_images: 1
12
+ validation_steps: 1000
13
+ conditioning: "random"
14
+ seed: 0
15
+ per_gpu_batch_size: 1
16
+ num_train_epochs: 600
17
+ max_train_steps: null
18
+ gradient_accumulation_steps: 1
19
+ gradient_checkpointing: false
20
+ learning_rate: 0.00001
21
+ reconstruction_guidance: 0
22
+ scale_lr: true
23
+ lr_scheduler: "constant"
24
+ lr_warmup_steps: 0
25
+ conditioning_dropout_prob: 0.1
26
+ use_8bit_adam: false
27
+ allow_tf32: false
28
+ use_ema: false
29
+ non_ema_revision: null
30
+ num_workers: 32
31
+ adam_beta1: 0.9
32
+ adam_beta2: 0.999
33
+ adam_weight_decay: 0.01
34
+ adam_epsilon: 0.0000001
35
+ max_grad_norm: 1.0
36
+ push_to_hub: false
37
+ hub_token: null
38
+ hub_model_id: null
39
+ logging_dir: "logs"
40
+ mixed_precision: null
41
+ report_to: "wandb"
42
+ local_rank: -1
43
+ checkpointing_steps: 500
44
+ checkpoints_total_limit: 2
45
+ enable_xformers_memory_efficient_attention: false
46
+ pretrain_unet: null
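
All three training configs use `conditioning: "random"`, which the pipeline in `training/svd_pipeline.py` below turns into a one-frame conditioning mask over the latent stack. The snippet is a minimal sketch of that masking step; the tensor shapes are illustrative assumptions.

```py
import numpy as np
import torch

# illustrative latent stack: [batch, frames, channels, height, width]
num_frames = 9
image_latents = torch.randn(1, num_frames, 4, 90, 160)

# "random" conditioning: keep exactly one frame of the stack, zero the rest
mask = torch.zeros_like(image_latents)
frame_idx = np.random.randint(0, num_frames)
mask[:, frame_idx] = 1

conditioned_latents = image_latents * mask
print(frame_idx, conditioned_latents.abs().sum(dim=(2, 3, 4)))  # only one frame is non-zero
```
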
training/svd_pipeline.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ import random
18
+ from typing import Callable, Dict, List, Optional, Union
19
+
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.image_processor import PipelineImageInput
26
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
27
+ from diffusers.schedulers import EulerDiscreteScheduler
28
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ import torch.nn.functional as F
33
+ from tqdm import tqdm
34
+ from utils import tensor_to_vae_latent
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+ EXAMPLE_DOC_STRING = """
40
+ Examples:
41
+ ```py
42
+ >>> from diffusers import StableVideoDiffusionPipeline
43
+ >>> from diffusers.utils import load_image, export_to_video
44
+
45
+ >>> pipe = StableVideoDiffusionPipeline.from_pretrained(
46
+ ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
47
+ ... )
48
+ >>> pipe.to("cuda")
49
+
50
+ >>> image = load_image(
51
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
52
+ ... )
53
+ >>> image = image.resize((1024, 576))
54
+
55
+ >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
56
+ >>> export_to_video(frames, "generated.mp4", fps=7)
57
+ ```
58
+ """
59
+
60
+
61
+ def _append_dims(x, target_dims):
62
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
63
+ dims_to_append = target_dims - x.ndim
64
+ if dims_to_append < 0:
65
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
66
+ return x[(...,) + (None,) * dims_to_append]
67
+
68
+
69
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
70
+ def retrieve_timesteps(
71
+ scheduler,
72
+ num_inference_steps: Optional[int] = None,
73
+ device: Optional[Union[str, torch.device]] = None,
74
+ timesteps: Optional[List[int]] = None,
75
+ sigmas: Optional[List[float]] = None,
76
+ **kwargs,
77
+ ):
78
+ r"""
79
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
80
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
81
+
82
+ Args:
83
+ scheduler (`SchedulerMixin`):
84
+ The scheduler to get timesteps from.
85
+ num_inference_steps (`int`):
86
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
87
+ must be `None`.
88
+ device (`str` or `torch.device`, *optional*):
89
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
90
+ timesteps (`List[int]`, *optional*):
91
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
92
+ `num_inference_steps` and `sigmas` must be `None`.
93
+ sigmas (`List[float]`, *optional*):
94
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
95
+ `num_inference_steps` and `timesteps` must be `None`.
96
+
97
+ Returns:
98
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
99
+ second element is the number of inference steps.
100
+ """
101
+ if timesteps is not None and sigmas is not None:
102
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
103
+ if timesteps is not None:
104
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
105
+ if not accepts_timesteps:
106
+ raise ValueError(
107
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
108
+ f" timestep schedules. Please check whether you are using the correct scheduler."
109
+ )
110
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
111
+ timesteps = scheduler.timesteps
112
+ num_inference_steps = len(timesteps)
113
+ elif sigmas is not None:
114
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
115
+ if not accept_sigmas:
116
+ raise ValueError(
117
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
119
+ )
120
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ num_inference_steps = len(timesteps)
123
+ else:
124
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ return timesteps, num_inference_steps
127
+
128
+
129
+ @dataclass
130
+ class StableVideoDiffusionPipelineOutput(BaseOutput):
131
+ r"""
132
+ Output class for Stable Video Diffusion pipeline.
133
+
134
+ Args:
135
+ frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
136
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
137
+ num_frames, height, width, num_channels)`.
138
+ """
139
+
140
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
141
+
142
+
143
+ class StableVideoDiffusionPipeline(DiffusionPipeline):
144
+ r"""
145
+ Pipeline to generate video from an input image using Stable Video Diffusion.
146
+
147
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
148
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
149
+
150
+ Args:
151
+ vae ([`AutoencoderKLTemporalDecoder`]):
152
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
153
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
154
+ Frozen CLIP image-encoder
155
+ ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
156
+ unet ([`UNetSpatioTemporalConditionModel`]):
157
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
158
+ scheduler ([`EulerDiscreteScheduler`]):
159
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
160
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
161
+ A `CLIPImageProcessor` to extract features from generated images.
162
+ """
163
+
164
+ model_cpu_offload_seq = "image_encoder->unet->vae"
165
+ _callback_tensor_inputs = ["latents"]
166
+
167
+ def __init__(
168
+ self,
169
+ vae: AutoencoderKLTemporalDecoder,
170
+ image_encoder: CLIPVisionModelWithProjection,
171
+ unet: UNetSpatioTemporalConditionModel,
172
+ scheduler: EulerDiscreteScheduler,
173
+ feature_extractor: CLIPImageProcessor,
174
+ ):
175
+ super().__init__()
176
+
177
+ self.register_modules(
178
+ vae=vae,
179
+ image_encoder=image_encoder,
180
+ unet=unet,
181
+ scheduler=scheduler,
182
+ feature_extractor=feature_extractor,
183
+ )
184
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
185
+ self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
186
+
187
+
188
+
189
+ def _encode_image(
190
+ self,
191
+ image: PipelineImageInput,
192
+ device: Union[str, torch.device],
193
+ num_videos_per_prompt: int,
194
+ do_classifier_free_guidance: bool,
195
+ ) -> torch.Tensor:
196
+ dtype = next(self.image_encoder.parameters()).dtype
197
+
198
+ if not isinstance(image, torch.Tensor):
199
+ image = self.video_processor.pil_to_numpy(image)
200
+ image = self.video_processor.numpy_to_pt(image)
201
+
202
+ # We normalize the image before resizing to match the original implementation.
203
+ # Then we unnormalize it after resizing.
204
+ image = image * 2.0 - 1.0
205
+ image = _resize_with_antialiasing(image, (224, 224))
206
+ image = (image + 1.0) / 2.0
207
+
208
+
209
+ # Normalize the image for CLIP input
210
+ image = self.feature_extractor(
211
+ images=image,
212
+ do_normalize=True,
213
+ do_center_crop=False,
214
+ do_resize=False,
215
+ do_rescale=False,
216
+ return_tensors="pt",
217
+ ).pixel_values
218
+
219
+ image = image.to(device=device, dtype=dtype)
220
+ image_embeddings = self.image_encoder(image).image_embeds
221
+ image_embeddings = image_embeddings.unsqueeze(1)
222
+
223
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
224
+ bs_embed, seq_len, _ = image_embeddings.shape
225
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
226
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
227
+
228
+ if do_classifier_free_guidance:
229
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
230
+
231
+ # For classifier free guidance, we need to do two forward passes.
232
+ # Here we concatenate the unconditional and text embeddings into a single batch
233
+ # to avoid doing two forward passes
234
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
235
+
236
+ return image_embeddings
237
+
238
+ def _encode_vae_image(
239
+ self,
240
+ image: torch.Tensor,
241
+ device: Union[str, torch.device],
242
+ num_videos_per_prompt: int,
243
+ do_classifier_free_guidance: bool,
244
+ ):
245
+ image = image.to(device=device)
246
+ image_latents = self.vae.encode(image).latent_dist.mode()
247
+
248
+ # duplicate image_latents for each generation per prompt, using mps friendly method
249
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
250
+
251
+ if do_classifier_free_guidance:
252
+ negative_image_latents = torch.zeros_like(image_latents)
253
+
254
+ # For classifier free guidance, we need to do two forward passes.
255
+ # Here we concatenate the unconditional and text embeddings into a single batch
256
+ # to avoid doing two forward passes
257
+ image_latents = torch.cat([negative_image_latents, image_latents])
258
+
259
+ return image_latents
260
+
261
+ def _get_add_time_ids(
262
+ self,
263
+ fps: int,
264
+ motion_bucket_id: int,
265
+ noise_aug_strength: float,
266
+ dtype: torch.dtype,
267
+ batch_size: int,
268
+ num_videos_per_prompt: int,
269
+ do_classifier_free_guidance: bool,
270
+ ):
271
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
272
+
273
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
274
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
275
+
276
+ if expected_add_embed_dim != passed_add_embed_dim:
277
+ raise ValueError(
278
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
279
+ )
280
+
281
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
282
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
283
+
284
+ if do_classifier_free_guidance:
285
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
286
+
287
+ return add_time_ids
288
+
289
+ def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
290
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
291
+ latents = latents.flatten(0, 1)
292
+
293
+ latents = 1 / self.vae.config.scaling_factor * latents
294
+
295
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
296
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
297
+
298
+ # decode decode_chunk_size frames at a time to avoid OOM
299
+ frames = []
300
+ for i in range(0, latents.shape[0], decode_chunk_size):
301
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
302
+ decode_kwargs = {}
303
+ if accepts_num_frames:
304
+ # we only pass num_frames_in if it's expected
305
+ decode_kwargs["num_frames"] = num_frames_in
306
+
307
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
308
+ frames.append(frame)
309
+ frames = torch.cat(frames, dim=0)
310
+
311
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
312
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
313
+
314
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
315
+ frames = frames.float()
316
+ return frames
317
+
318
+ def check_inputs(self, image, height, width):
319
+ if (
320
+ not isinstance(image, torch.Tensor)
321
+ and not isinstance(image, PIL.Image.Image)
322
+ and not isinstance(image, list)
323
+ ):
324
+ raise ValueError(
325
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
326
+ f" {type(image)}"
327
+ )
328
+
329
+ if height % 8 != 0 or width % 8 != 0:
330
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
331
+
332
+ def prepare_latents(
333
+ self,
334
+ batch_size: int,
335
+ num_frames: int,
336
+ num_channels_latents: int,
337
+ height: int,
338
+ width: int,
339
+ dtype: torch.dtype,
340
+ device: Union[str, torch.device],
341
+ generator: torch.Generator,
342
+ latents: Optional[torch.Tensor] = None,
343
+ ):
344
+ shape = (
345
+ batch_size,
346
+ num_frames,
347
+ num_channels_latents // 2,
348
+ height // self.vae_scale_factor,
349
+ width // self.vae_scale_factor,
350
+ )
351
+ if isinstance(generator, list) and len(generator) != batch_size:
352
+ raise ValueError(
353
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
354
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
355
+ )
356
+
357
+ if latents is None:
358
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
359
+ else:
360
+ latents = latents.to(device)
361
+
362
+ # scale the initial noise by the standard deviation required by the scheduler
363
+ latents = latents * self.scheduler.init_noise_sigma
364
+ return latents
365
+
366
+ @property
367
+ def guidance_scale(self):
368
+ return self._guidance_scale
369
+
370
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
371
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
372
+ # corresponds to doing no classifier free guidance.
373
+ @property
374
+ def do_classifier_free_guidance(self):
375
+ if isinstance(self.guidance_scale, (int, float)):
376
+ return self.guidance_scale > 0
377
+ return self.guidance_scale.max() > 0
378
+
379
+ @property
380
+ def num_timesteps(self):
381
+ return self._num_timesteps
382
+
383
+ #@torch.no_grad()
384
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
385
+ def __call__(
386
+ self,
387
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
388
+ height: int = 576,
389
+ width: int = 1024,
390
+ num_frames: Optional[int] = None,
391
+ num_inference_steps: int = 25,
392
+ sigmas: Optional[List[float]] = None,
393
+ min_guidance_scale: float = 1.0,
394
+ max_guidance_scale: float = 3.0,
395
+ reconstruction_guidance_scale: float = 2.0,
396
+ fps: int = 7,
397
+ motion_bucket_id: int = 127,
398
+ noise_aug_strength: float = 0.02,
399
+ decode_chunk_size: Optional[int] = None,
400
+ num_videos_per_prompt: Optional[int] = 1,
401
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
402
+ latents: Optional[torch.Tensor] = None,
403
+ output_type: Optional[str] = "pil",
404
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
405
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
406
+ return_dict: bool = True,
407
+ conditioning: str = "zero",
408
+ focal_stack_num: int = None,
409
+ accelerator=None,
410
+ weight_dtype=None,
411
+ zero=0
412
+ ):
413
+ r"""
414
+ The call function to the pipeline for generation.
415
+
416
+ Args:
417
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
418
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
419
+ 1]`.
420
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
421
+ The height in pixels of the generated image.
422
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
423
+ The width in pixels of the generated image.
424
+ num_frames (`int`, *optional*):
425
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
426
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
427
+ num_inference_steps (`int`, *optional*, defaults to 25):
428
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
429
+ expense of slower inference. This parameter is modulated by `strength`.
430
+ sigmas (`List[float]`, *optional*):
431
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
432
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
433
+ will be used.
434
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
435
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
436
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
437
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
438
+ fps (`int`, *optional*, defaults to 7):
439
+ Frames per second. The rate at which the generated images shall be exported to a video after
440
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
441
+ motion_bucket_id (`int`, *optional*, defaults to 127):
442
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
443
+ will be in the video.
444
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
445
+ The amount of noise added to the init image, the higher it is the less the video will look like the
446
+ init image. Increase it for more motion.
447
+ decode_chunk_size (`int`, *optional*):
448
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
449
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
450
+ For lower memory usage, reduce `decode_chunk_size`.
451
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
452
+ The number of videos to generate per prompt.
453
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
454
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
455
+ generation deterministic.
456
+ latents (`torch.Tensor`, *optional*):
457
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
458
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
459
+ tensor is generated by sampling using the supplied random `generator`.
460
+ output_type (`str`, *optional*, defaults to `"pil"`):
461
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
462
+ callback_on_step_end (`Callable`, *optional*):
463
+ A function that is called at the end of each denoising step during inference. The function is called
464
+ with the following arguments:
465
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
466
+ `callback_kwargs` will include a list of all tensors as specified by
467
+ `callback_on_step_end_tensor_inputs`.
468
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
469
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
470
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
471
+ `._callback_tensor_inputs` attribute of your pipeline class.
472
+ return_dict (`bool`, *optional*, defaults to `True`):
473
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
474
+ plain tuple.
475
+
476
+ Examples:
477
+
478
+ Returns:
479
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
480
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
481
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
482
+ returned.
483
+ """
484
+
485
+ with torch.no_grad():
486
+
487
+ # 0. Default height and width to unet
488
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
489
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
490
+
491
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
492
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
493
+
494
+ # 1. Check inputs. Raise error if not correct
495
+ self.check_inputs(image, height, width)
496
+
497
+ # 2. Define call parameters
498
+ if isinstance(image, PIL.Image.Image):
499
+ batch_size = 1
500
+ elif isinstance(image, list):
501
+ batch_size = len(image)
502
+ else:
503
+ batch_size = image.shape[0]
504
+ device = self._execution_device
505
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
506
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
507
+ # corresponds to doing no classifier free guidance.
508
+ self._guidance_scale = max_guidance_scale
509
+
510
+
511
+
512
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
513
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
514
+ fps = fps - 1
515
+
516
+ # 4. Encode input image using VAE
517
+ # first_image = image[0, 0:1]
518
+ # first_image = self.video_processor.preprocess(first_image*0.5+0.5, height=height, width=width).to(device)
519
+ # noise = randn_tensor(first_image.shape, generator=generator, device=device, dtype=image.dtype)
520
+ # first_image = first_image + noise_aug_strength * noise #you add this noise to have a version of the image that the vae can denoise
521
+
522
+ # first_image = self.video_processor.preprocess(first_image*0.5+0.5, height=height, width=width).to(device)
523
+
524
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
525
+ if needs_upcasting:
526
+ self.vae.to(dtype=torch.float32)
527
+
528
+
529
+ image_latents = tensor_to_vae_latent(image, self.vae, otype="mode")/self.vae.config.scaling_factor
530
+ #noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=image.dtype)
531
+ #image_latents = image_latents + noise_aug_strength * noise #you add this noise to have a version of the image that the vae can denoise
532
+
533
+ # old_image_latents = self._encode_vae_image(
534
+ # first_image,
535
+ # device=device,
536
+ # num_videos_per_prompt=num_videos_per_prompt,
537
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
538
+ # )
539
+
540
+ if self.do_classifier_free_guidance:
541
+ negative_image_latents = torch.zeros_like(image_latents)
542
+
543
+ # For classifier free guidance, we need to do two forward passes.
544
+ # Here we concatenate the unconditional and text embeddings into a single batch
545
+ # to avoid doing two forward passes
546
+ image_latents = torch.cat([negative_image_latents, image_latents])
547
+
548
+ image_latents = image_latents.to(torch.float32)
549
+
550
+ # cast back to fp16 if needed
551
+ if needs_upcasting:
552
+ self.vae.to(dtype=torch.float16)
553
+
554
+ # Repeat the image latents for each frame so we can concatenate them with the noise
555
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
556
+ #image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
557
+ mask = torch.zeros_like(image_latents)
558
+
559
+ if focal_stack_num is not None:
560
+ frame_idx = focal_stack_num
561
+ mask[:, frame_idx] = 1
562
+ elif conditioning == "zero":
563
+ frame_idx = 0
564
+ mask[:, 0] = 1
565
+ elif conditioning == "random":
566
+ rand_idx = np.random.randint(0, num_frames) #randomly choose a frame index in [0, num_frames - 1] to condition on
567
+ frame_idx = rand_idx
568
+ mask[:, rand_idx] = 1
569
+ elif conditioning in ["ablate_position", "ablate_time"]:
570
+ frame_idx = 0 #zero for simple testing (this won't be hit at testing time)
571
+ elif conditioning == "five":
572
+ frame_idx = 4
573
+ mask[:, 4] = 1
574
+
575
+ original_image_latents = image_latents.clone()
576
+ if conditioning in ["ablate_position", "ablate_time"]:
577
+ image_latents = image_latents[:, frame_idx:frame_idx+1].repeat(1,num_frames, 1, 1, 1)
578
+ else:
579
+ image_latents = image_latents * mask
580
+
581
+ mask = mask == 1 #mask is a boolean tensor
582
+
583
+
584
+ clip_image = image[0, frame_idx: frame_idx+1]
585
+ resized_clip_image = _resize_with_antialiasing(clip_image, (224, 224))
586
+ image_embeddings = self._encode_image(resized_clip_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
587
+
588
+ if motion_bucket_id is None: #this hits for the "ablate_time" conditioning at validation time
589
+ motion_bucket_id = 0
590
+
591
+
592
+ # 5. Get Added Time IDs
593
+ added_time_ids = self._get_add_time_ids(
594
+ fps,
595
+ motion_bucket_id,
596
+ noise_aug_strength,
597
+ image_embeddings.dtype,
598
+ batch_size,
599
+ num_videos_per_prompt,
600
+ self.do_classifier_free_guidance,
601
+ )
602
+ added_time_ids = added_time_ids.to(device)
603
+
604
+ # 6. Prepare timesteps
605
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
606
+
607
+ # 7. Prepare latent variables
608
+ num_channels_latents = self.unet.config.in_channels
609
+ latents = self.prepare_latents(
610
+ batch_size * num_videos_per_prompt,
611
+ num_frames,
612
+ num_channels_latents,
613
+ height,
614
+ width,
615
+ image_embeddings.dtype,
616
+ device,
617
+ generator,
618
+ latents,
619
+ )
620
+
621
+ # 8. Prepare guidance scale
622
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
623
+ guidance_scale = guidance_scale.to(device, latents.dtype)
624
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
625
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
626
+
627
+ self._guidance_scale = guidance_scale
628
+
629
+ # 9. Denoising loop
630
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
631
+ self._num_timesteps = len(timesteps)
632
+
633
+
634
+ alphas_cumprod = 1 / (1 + self.scheduler.sigmas**2)
635
+ alphas = alphas_cumprod / torch.cat((torch.tensor([1.0]), alphas_cumprod[:-1]))
636
+
637
+
638
+ progress_bar = tqdm(range(num_inference_steps), disable=not accelerator.is_local_main_process)
639
+ for i, t in enumerate(timesteps):
640
+ # expand the latents if we are doing classifier free guidance - this is because we have the unconditional and the conditional portion
641
+ #this is concatenation along the batch dimension
642
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
643
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
644
+
645
+ # Concatenate image_latents over channels dimension
646
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
647
+
648
+ # predict the noise residual
649
+ with torch.no_grad():
650
+ noise_pred_uncond = self.unet(
651
+ latent_model_input[0:1],
652
+ t,
653
+ encoder_hidden_states=image_embeddings[0:1],
654
+ added_time_ids=added_time_ids[0:1],
655
+ return_dict=False,
656
+ )[0]
657
+
658
+ noise_pred_cond = self.unet(
659
+ latent_model_input[1:2],
660
+ t,
661
+ encoder_hidden_states=image_embeddings[1:2],
662
+ added_time_ids=added_time_ids[1:2],
663
+ return_dict=False,
664
+ )[0]
665
+
666
+ with torch.no_grad():
667
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred_cond])
668
+
669
+ # perform guidance
670
+ if self.do_classifier_free_guidance:
671
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
672
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
673
+ # compute the previous noisy sample x_t -> x_t-1
674
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
675
+
676
+
677
+ if self.scheduler._step_index < len(timesteps) and reconstruction_guidance_scale > 0:
678
+ noise_pred = self.unet(
679
+ torch.cat([latents, image_latents[1:2]], dim=2),
680
+ t,
681
+ encoder_hidden_states=image_embeddings[1:2],
682
+ added_time_ids=added_time_ids[1:2],
683
+ return_dict=False,
684
+ )[0]
685
+ reconstructed_latent_cond = self.scheduler.step(noise_pred, t, latents).pred_original_sample #predicted x_0 given the current noise estimate
686
+ self.scheduler._step_index-=1 #undo the step-index increment from this extra scheduler.step call
687
+ reconstruction_loss = F.mse_loss((image_latents[1, mask[1]]).to(torch.float32)*self.vae.config.scaling_factor, reconstructed_latent_cond[mask[1:2]], reduction="mean") #mean squared error between the conditioning latent and its reconstruction
688
+ reconstruction_grad = torch.autograd.grad(reconstruction_loss, reconstructed_latent_cond, retain_graph=True)[0]
689
+ accelerator.backward(reconstruction_loss)
690
+ latents = latents - reconstruction_guidance_scale*alphas[self.scheduler.step_index]*reconstruction_grad
691
+
692
+ with torch.no_grad():
693
+ if callback_on_step_end is not None:
694
+ callback_kwargs = {}
695
+ for k in callback_on_step_end_tensor_inputs:
696
+ callback_kwargs[k] = locals()[k]
697
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
698
+ latents = callback_outputs.pop("latents", latents)
699
+
700
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
701
+ progress_bar.update()
702
+
703
+
704
+ with torch.no_grad():
705
+ if not output_type == "latent":
706
+ # cast back to fp16 if needed
707
+ if needs_upcasting:
708
+ self.vae.to(dtype=torch.float16)
709
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
710
+ gt = self.decode_latents(original_image_latents[1:2]*self.vae.config.scaling_factor, num_frames, decode_chunk_size)
711
+ else:
712
+ frames = latents
713
+
714
+ self.maybe_free_model_hooks()
715
+
716
+ if not return_dict:
717
+ return frames
718
+
719
+ return StableVideoDiffusionPipelineOutput(frames=frames), gt
720
+
721
+
722
+ # resizing utils
723
+ # TODO: clean up later
724
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
725
+ h, w = input.shape[-2:]
726
+ factors = (h / size[0], w / size[1])
727
+
728
+ # First, we have to determine sigma
729
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
730
+ sigmas = (
731
+ max((factors[0] - 1.0) / 2.0, 0.001),
732
+ max((factors[1] - 1.0) / 2.0, 0.001),
733
+ )
734
+
735
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
736
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
737
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
738
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
739
+
740
+ # Make sure it is odd
741
+ if (ks[0] % 2) == 0:
742
+ ks = ks[0] + 1, ks[1]
743
+
744
+ if (ks[1] % 2) == 0:
745
+ ks = ks[0], ks[1] + 1
746
+
747
+ input = _gaussian_blur2d(input, ks, sigmas)
748
+
749
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
750
+ return output
751
+
752
+
753
+ def _compute_padding(kernel_size):
754
+ """Compute padding tuple."""
755
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
756
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
757
+ if len(kernel_size) < 2:
758
+ raise AssertionError(kernel_size)
759
+ computed = [k - 1 for k in kernel_size]
760
+
761
+ # for even kernels we need to do asymmetric padding :(
762
+ out_padding = 2 * len(kernel_size) * [0]
763
+
764
+ for i in range(len(kernel_size)):
765
+ computed_tmp = computed[-(i + 1)]
766
+
767
+ pad_front = computed_tmp // 2
768
+ pad_rear = computed_tmp - pad_front
769
+
770
+ out_padding[2 * i + 0] = pad_front
771
+ out_padding[2 * i + 1] = pad_rear
772
+
773
+ return out_padding
774
+
775
+
776
+ def _filter2d(input, kernel):
777
+ # prepare kernel
778
+ b, c, h, w = input.shape
779
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
780
+
781
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
782
+
783
+ height, width = tmp_kernel.shape[-2:]
784
+
785
+ padding_shape: List[int] = _compute_padding([height, width])
786
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
787
+
788
+ # kernel and input tensor reshape to align element-wise or batch-wise params
789
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
790
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
791
+
792
+ # convolve the tensor with the kernel.
793
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
794
+
795
+ out = output.view(b, c, h, w)
796
+ return out
797
+
798
+
799
+ def _gaussian(window_size: int, sigma):
800
+ if isinstance(sigma, float):
801
+ sigma = torch.tensor([[sigma]])
802
+
803
+ batch_size = sigma.shape[0]
804
+
805
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
806
+
807
+ if window_size % 2 == 0:
808
+ x = x + 0.5
809
+
810
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
811
+
812
+ return gauss / gauss.sum(-1, keepdim=True)
813
+
814
+
815
+ def _gaussian_blur2d(input, kernel_size, sigma):
816
+ if isinstance(sigma, tuple):
817
+ sigma = torch.tensor([sigma], dtype=input.dtype)
818
+ else:
819
+ sigma = sigma.to(dtype=input.dtype)
820
+
821
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
822
+ bs = sigma.shape[0]
823
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
824
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
825
+ out_x = _filter2d(input, kernel_x[..., None, :])
826
+ out = _filter2d(out_x, kernel_y[..., None])
827
+
828
+ return out
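
The reconstruction-guidance branch in `__call__` above takes the gradient of an MSE between the known conditioning latent and the predicted clean latent (`pred_original_sample`) and nudges the noisy latents against it, scaled by `alpha_t`. The snippet below is a toy sketch of that update in which a differentiable surrogate stands in for the extra UNet pass; all tensors and constants are made up for illustration and this is not the pipeline's actual code path.

```py
import torch
import torch.nn.functional as F

latents = torch.randn(1, 9, 4, 16, 16)      # current noisy latents
target_latent = torch.randn(4, 16, 16)      # known conditioning latent for one frame
frame_idx, alpha_t, rec_scale = 4, 0.9, 2.0

# surrogate for pred_original_sample; in the pipeline this comes from an extra
# conditional UNet pass followed by scheduler.step(...).pred_original_sample
pred_x0 = (latents * 0.8).detach().requires_grad_(True)

loss = F.mse_loss(target_latent, pred_x0[0, frame_idx], reduction="mean")
grad = torch.autograd.grad(loss, pred_x0)[0]

# gradient step on the noisy latents, scaled as in the pipeline
latents = latents - rec_scale * alpha_t * grad
print(loss.item(), latents.shape)
```
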
training/svd_runner.py ADDED
@@ -0,0 +1,683 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Script to fine-tune Stable Video Diffusion."""
18
+
19
+ from datetime import datetime
20
+ import logging
21
+ import math
22
+ import os
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ from torch.utils.data import RandomSampler
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from tqdm.auto import tqdm
39
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
40
+ from validation import valid_net
41
+ import diffusers
42
+ from svd_pipeline import StableVideoDiffusionPipeline
43
+ from diffusers.models.lora import LoRALinearLayer
44
+ from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, UNetSpatioTemporalConditionModel
45
+ from diffusers.image_processor import VaeImageProcessor
46
+ from diffusers.optimization import get_scheduler
47
+ from diffusers.training_utils import EMAModel
48
+ from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
49
+ from diffusers.utils.import_utils import is_xformers_available
50
+ from utils import parse_args, FocalStackDataset, OutsidePhotosDataset, rand_log_normal, tensor_to_vae_latent, load_image, _resize_with_antialiasing, encode_image, get_add_time_ids
51
+ import wandb
52
+ import random
53
+ from random import choices
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.24.0.dev0")
56
+
57
+ logger = get_logger(__name__, log_level="INFO")
58
+
59
+ import numpy as np
60
+ import PIL.Image
61
+ import torch
62
+ from typing import Callable, Dict, List, Optional, Union
63
+ import os
64
+
65
+
66
+
67
+ def main():
68
+ args = parse_args()
69
+
70
+ # Set up the PyTorch CUDA allocator - without this I run into memory overflow
71
+ # PyTorch 2.4.1 is important for expandable_segments to work
72
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
73
+
74
+ if not is_wandb_available():
75
+ raise ImportError(
76
+ "Make sure to install wandb if you want to use it for logging during training.")
77
+ import wandb
78
+
79
+
80
+ currentSecond= datetime.now().second
81
+ currentMinute = datetime.now().minute
82
+ currentHour = datetime.now().hour
83
+ currentDay = datetime.now().day
84
+ currentMonth = datetime.now().month
85
+ currentYear = datetime.now().year
86
+
87
+
88
+ if args.non_ema_revision is not None:
89
+ deprecate(
90
+ "non_ema_revision!=None",
91
+ "0.15.0",
92
+ message=(
93
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
94
+ " use `--variant=non_ema` instead."
95
+ ),
96
+ )
97
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
98
+ accelerator_project_config = ProjectConfiguration(
99
+ project_dir=args.output_dir, logging_dir=logging_dir)
100
+ ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)
101
+ accelerator = Accelerator(
102
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
103
+ mixed_precision=args.mixed_precision,
104
+ log_with=args.report_to,
105
+ project_config=accelerator_project_config,
106
+ kwargs_handlers=[ddp_kwargs]
107
+ )
108
+
109
+ accelerator.init_trackers(
110
+ project_name=args.wandb_project,
111
+ init_kwargs={"wandb": { "name" : args.run_name}}
112
+ )
113
+
114
+ generator = torch.Generator(
115
+ device=accelerator.device).manual_seed(args.seed)
116
+
117
+
118
+
119
+
120
+ # Make one log on every process with the configuration for debugging.
121
+ logging.basicConfig(
122
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
123
+ datefmt="%m/%d/%Y %H:%M:%S",
124
+ level=logging.INFO,
125
+ )
126
+ logger.info(accelerator.state, main_process_only=False)
127
+ if accelerator.is_local_main_process:
128
+ transformers.utils.logging.set_verbosity_warning()
129
+ diffusers.utils.logging.set_verbosity_info()
130
+ else:
131
+ transformers.utils.logging.set_verbosity_error()
132
+ diffusers.utils.logging.set_verbosity_error()
133
+
134
+ # If passed along, set the training seed now.
135
+ if args.seed is not None:
136
+ set_seed(args.seed)
137
+
138
+ # Handle the repository creation
139
+ if accelerator.is_main_process:
140
+ if args.output_dir is not None:
141
+ os.makedirs(args.output_dir, exist_ok=True)
142
+
143
+ if args.push_to_hub:
144
+ repo_id = create_repo(
145
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
146
+ ).repo_id
147
+
148
+ # Load img encoder, tokenizer and models.
149
+ feature_extractor = CLIPImageProcessor.from_pretrained(
150
+ args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision
151
+ )
152
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
153
+ args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision
154
+ )
155
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
156
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16")
157
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
158
+ args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet,
159
+ subfolder="unet",
160
+ low_cpu_mem_usage=True,
161
+ variant="fp16"
162
+ )
163
+
164
+ #unet= UNetSpatioTemporalConditionModel()
165
+
166
+ # Freeze vae and image_encoder
167
+ vae.requires_grad_(False)
168
+ image_encoder.requires_grad_(False)
169
+
170
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
171
+ # as these models are only used for inference, keeping weights in full precision is not required.
172
+ weight_dtype = torch.float32
173
+ if accelerator.mixed_precision == "fp16":
174
+ weight_dtype = torch.float16
175
+ elif accelerator.mixed_precision == "bf16":
176
+ weight_dtype = torch.bfloat16
177
+
178
+ # Move image_encoder and vae to gpu and cast to weight_dtype
179
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
180
+ vae.to(accelerator.device, dtype=weight_dtype)
181
+
182
+ # Create EMA for the unet.
183
+ if args.use_ema:
184
+ ema_unet = EMAModel(unet.parameters(
185
+ ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config, use_ema_warmup=True, inv_gamma=1, power=3/4)
186
+
187
+
188
+
189
+ if args.enable_xformers_memory_efficient_attention:
190
+ if is_xformers_available():
191
+ import xformers
192
+
193
+ xformers_version = version.parse(xformers.__version__)
194
+ if xformers_version == version.parse("0.0.16"):
195
+ logger.warn(
196
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
197
+ )
198
+ unet.enable_xformers_memory_efficient_attention()
199
+ else:
200
+ raise ValueError(
201
+ "xformers is not available. Make sure it is installed correctly")
202
+
203
+ # `accelerate` 0.16.0 will have better support for customized saving
204
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
205
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
206
+ def save_model_hook(models, weights, output_dir):
207
+ if args.use_ema:
208
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
209
+
210
+ for i, model in enumerate(models):
211
+ model.save_pretrained(os.path.join(output_dir, "unet"))
212
+
213
+
214
+ # make sure to pop weight so that corresponding model is not saved again
215
+ weights.pop()
216
+
217
+ def load_model_hook(models, input_dir):
218
+
219
+ if args.use_ema:
220
+ load_model = EMAModel.from_pretrained(os.path.join(
221
+ input_dir, "unet_ema"), UNetSpatioTemporalConditionModel)
222
+ ema_unet.load_state_dict(load_model.state_dict())
223
+ ema_unet.to(accelerator.device)
224
+ del load_model
225
+
226
+ for i in range(len(models)):
227
+ # pop models so that they are not loaded again
228
+ model = models.pop()
229
+
230
+ # load diffusers style into model
231
+ load_model = UNetSpatioTemporalConditionModel.from_pretrained(
232
+ input_dir, subfolder="unet")
233
+ model.register_to_config(**load_model.config)
234
+
235
+ model.load_state_dict(load_model.state_dict())
236
+ del load_model
237
+
238
+ accelerator.register_save_state_pre_hook(save_model_hook)
239
+ accelerator.register_load_state_pre_hook(load_model_hook)
240
+
241
+ if args.gradient_checkpointing:
242
+ unet.enable_gradient_checkpointing()
243
+
244
+ # Enable TF32 for faster training on Ampere GPUs,
245
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
246
+ if args.allow_tf32:
247
+ torch.backends.cuda.matmul.allow_tf32 = True
248
+
249
+ if args.scale_lr:
250
+ args.learning_rate = (
251
+ args.learning_rate * args.gradient_accumulation_steps *
252
+ args.per_gpu_batch_size * accelerator.num_processes
253
+ )
254
+
255
+ optimizer_cls = torch.optim.AdamW
256
+
257
+ parameters_list = []
258
+
259
+
260
+ # Customize the parameters that need to be trained; if necessary, you can uncomment them yourself.
261
+ for name, param in unet.named_parameters():
262
+ # only the temporal transformer blocks are passed to the optimizer and left trainable
263
+ if 'temporal_transformer_block' in name: #or 'conv_norm_out' in name or 'conv_out' in name or 'conv_in' in name or 'spatial_res_block' in name or 'up_block' in name:
264
+ parameters_list.append(param)
265
+ param.requires_grad = True
266
+ else:
267
+ param.requires_grad = False
268
+ zero_latent = 0
269
+
270
+
271
+
272
+ optimizer = optimizer_cls(
273
+ parameters_list,
274
+ lr=args.learning_rate,
275
+ betas=(args.adam_beta1, args.adam_beta2),
276
+ weight_decay=args.adam_weight_decay,
277
+ eps=args.adam_epsilon,
278
+ )
279
+
280
+ # DataLoaders creation:
281
+ args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
282
+
283
+ if args.photos:
284
+ train_dataset = OutsidePhotosDataset(data_folder=args.data_folder, sample_frames=args.num_frames)
285
+ val_dataset = OutsidePhotosDataset(data_folder=args.data_folder, sample_frames=args.num_frames)
286
+ else:
287
+ train_dataset = FocalStackDataset(args.data_folder, args.splits_dir, sample_frames=args.num_frames, split="train")
288
+ val_dataset = FocalStackDataset(args.data_folder, args.splits_dir, sample_frames=args.num_frames, split="val" if not args.test else "test")
289
+ sampler = RandomSampler(train_dataset)
290
+ train_dataloader = torch.utils.data.DataLoader(
291
+ train_dataset,
292
+ sampler=sampler,
293
+ batch_size=args.per_gpu_batch_size,
294
+ num_workers=args.num_workers,
295
+ drop_last=True
296
+ )
297
+ val_dataloader = torch.utils.data.DataLoader(
298
+ val_dataset,
299
+ batch_size=args.per_gpu_batch_size,
300
+ num_workers=args.num_workers,
301
+ shuffle=False,
302
+ )
303
+
304
+ # Scheduler and math around the number of training steps.
305
+ overrode_max_train_steps = False
306
+ num_update_steps_per_epoch = math.ceil(
307
+ len(train_dataloader) / args.gradient_accumulation_steps)
308
+ if args.max_train_steps is None:
309
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
310
+ overrode_max_train_steps = True
311
+
312
+ lr_scheduler = get_scheduler(
313
+ args.lr_scheduler,
314
+ optimizer=optimizer,
315
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
316
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
317
+ )
318
+
319
+
320
+
321
+
322
+ # Prepare everything with our `accelerator`.
323
+ unet, optimizer, lr_scheduler, train_dataloader, val_dataloader = accelerator.prepare(
324
+ unet, optimizer, lr_scheduler, train_dataloader, val_dataloader
325
+ )
326
+
327
+ if args.use_ema:
328
+ ema_unet.to(accelerator.device)
329
+
330
+
331
+
332
+ # attribute handling for models using DDP
333
+ if isinstance(unet, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
334
+ unet = unet.module
335
+
336
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
337
+ num_update_steps_per_epoch = math.ceil(
338
+ len(train_dataloader) / args.gradient_accumulation_steps)
339
+ if overrode_max_train_steps:
340
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
341
+ # Afterwards we recalculate our number of training epochs
342
+ args.num_train_epochs = math.ceil(
343
+ args.max_train_steps / num_update_steps_per_epoch)
344
+
345
+ # We need to initialize the trackers we use, and also store our configuration.
346
+ # The trackers initialize automatically on the main process.
347
+ if accelerator.is_main_process:
348
+ accelerator.init_trackers("SVDXtend", config=vars(args))
349
+
350
+ # Train!
351
+ total_batch_size = args.per_gpu_batch_size * \
352
+ accelerator.num_processes * args.gradient_accumulation_steps
353
+
354
+ logger.info("***** Running training *****")
355
+ logger.info(f" Num examples = {len(train_dataset)}")
356
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
357
+ logger.info(
358
+ f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
359
+ logger.info(
360
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
361
+ logger.info(
362
+ f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
363
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
364
+ global_step = 0
365
+ first_epoch = 0
366
+
367
+
368
+ # Potentially load in the weights and states from a previous save
369
+ if args.load_from_checkpoint:
370
+
371
+ path = args.load_from_checkpoint
372
+ #
373
+ if path is None:
374
+ accelerator.print(
375
+ f"Checkpoint '{args.load_from_checkpoint}' does not exist. Starting a new training run."
376
+ )
377
+ args.load_from_checkpoint = None
378
+ else:
379
+ accelerator.print(f"Resuming from checkpoint {path}")
380
+ accelerator.load_state(path, strict=False)
381
+ global_step = int(os.path.basename(path).split("-")[1])
382
+
383
+ resume_global_step = global_step * args.gradient_accumulation_steps
384
+ first_epoch = global_step // num_update_steps_per_epoch
385
+
386
+ resume_step = resume_global_step % (
387
+ num_update_steps_per_epoch * args.gradient_accumulation_steps)
388
+
389
+ # Only show the progress bar once on each machine.
390
+ progress_bar = tqdm(range(global_step, args.max_train_steps),
391
+ disable=not accelerator.is_local_main_process)
392
+ progress_bar.set_description("Steps")
393
+
394
+ # print("ARGS PHOTOS: ", args.photos)
395
+ # if args.photos:
396
+ # print("MAKING OUTSIDE PHOTOS DATASET")
397
+ # train_dataset = OutsidePhotosDataset(data_folder=args.data_folder, sample_frames=args.num_frames)
398
+ # val_dataset = OutsidePhotosDataset(data_folder=args.data_folder, sample_frames=args.num_frames)
399
+
400
+ # sampler = RandomSampler(train_dataset)
401
+ # train_dataloader = torch.utils.data.DataLoader(
402
+ # train_dataset,
403
+ # sampler=sampler,
404
+ # batch_size=args.per_gpu_batch_size,
405
+ # num_workers=args.num_workers,
406
+ # drop_last=True
407
+ # )
408
+ # val_dataloader = torch.utils.data.DataLoader(
409
+ # val_dataset,
410
+ # batch_size=args.per_gpu_batch_size,
411
+ # num_workers=args.num_workers,
412
+ # shuffle=False,
413
+ # )
414
+
415
+ # train_dataloader, val_dataloader = accelerator.prepare(
416
+ # train_dataloader, val_dataloader)
417
+ if args.test:
418
+ first_epoch = 0 #just so I enter loop for test (regardless of training iterations)
419
+
420
+ for epoch in range(first_epoch, args.num_train_epochs):
421
+ train_loss = 0.0
422
+ for step, batch in enumerate(train_dataloader):
423
+ unet.train()
424
+ if not args.test:
425
+ with accelerator.accumulate(unet):
426
+ # first, convert images to latent space.
427
+ pixel_values = batch["pixel_values"].to(weight_dtype).to(
428
+ accelerator.device, non_blocking=True
429
+ )
430
+
431
+
432
+ conditional_pixel_values = pixel_values
433
+ latents = tensor_to_vae_latent(pixel_values, vae, otype="sample")
434
+
435
+ noise = torch.randn_like(latents)
436
+ bsz = latents.shape[0]
437
+
438
+ cond_sigmas = rand_log_normal(shape=[bsz,], loc=-3.0, scale=0.5).to(latents)
439
+ noise_aug_strength = cond_sigmas[0] # TODO: support batch > 1
440
+ cond_sigmas = cond_sigmas[:, None, None, None, None]
441
+
442
+ conditional_pixel_values = \
443
+ torch.randn_like(conditional_pixel_values) * cond_sigmas + conditional_pixel_values # noise augmentation on the conditioning frames; comment this line out to condition on clean frames
444
+ conditional_latents = tensor_to_vae_latent(conditional_pixel_values, vae, otype="sample")
445
+ conditional_latents = conditional_latents / vae.config.scaling_factor
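+ # tensor_to_vae_latent multiplies by vae.config.scaling_factor, so dividing here leaves the
+ # conditioning latents unscaled, matching how the SVD pipeline concatenates raw image latents.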
446
+
447
+ ## noisy conditioning: the conditioning frames receive the same noise augmentation used in SVD
448
+
449
+ # Sample a random timestep for each image
450
+ # P_mean=0.7 P_std=1.6
451
+ sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device)
452
+ # Add noise to the latents according to the noise magnitude at each timestep
453
+ # (this is the forward diffusion process)
454
+ sigmas = sigmas[:, None, None, None, None]
455
+ noisy_latents = latents + noise * sigmas
456
+ timesteps = torch.Tensor(
457
+ [0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device)
458
+
459
+ inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
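+ # EDM preconditioning (https://arxiv.org/abs/2206.00364): the network input is scaled by
+ # c_in = 1 / sqrt(sigma^2 + 1) and the timestep embedding uses c_noise = 0.25 * ln(sigma).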
460
+
461
+
462
+ conditioning = args.conditioning
463
+ # Create a tensor of zeros with the same shape as the repeated conditional_latents
464
+ if conditioning == "zero":
465
+ random_frames = [0]
466
+ elif conditioning == "random":
467
+ #choose a random number between 0 and 8 inclusive
468
+ random_frames = [np.random.randint(0, args.num_frames)]
469
+ elif conditioning in ["ablate_position", "ablate_time"] :
470
+ random_frames = [np.random.randint(0, args.num_frames)]
471
+ elif conditioning == "ablate_single_frame":
472
+ input_random_frame = np.random.randint(0, args.num_frames)
473
+ output_random_frame = np.random.randint(0, args.num_frames)
474
+ elif conditioning == "random_single_double_triple":
475
+ num_imgs = random.randint(1, 3)
476
+ random_frames = choices(range(args.num_frames), k=num_imgs)
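+ # Conditioning modes: "zero" always conditions on frame 0, "random" and the position/time ablations
+ # pick one random focal position, "ablate_single_frame" draws separate input/output frame indices,
+ # and "random_single_double_triple" conditions on 1-3 randomly chosen frames (with replacement).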
477
+
478
+ # Get the text embedding for conditioning.
479
+ encoder_hidden_states = encode_image(
480
+ pixel_values[:, random_frames[0], :, :, :].float(),
481
+ feature_extractor, image_encoder, weight_dtype, accelerator)
482
+
483
+ # Here I input a fixed numerical value for 'motion_bucket_id', which is not reasonable.
484
+ # However, I am unable to fully align with the calculation method of the motion score,
485
+ # so I adopted this approach. The same applies to the 'fps' (frames per second).
486
+ conditioning_num = 0
487
+
488
+ if conditioning != "ablate_time":
489
+ conditioning_num = 0
490
+ else:
491
+ conditioning_num = random_frames[0]
492
+
493
+
494
+
495
+ added_time_ids = get_add_time_ids(
496
+ 7, # fixed
497
+ conditioning_num, # motion_bucket_id = 127, fixed
498
+ noise_aug_strength, # noise_aug_strength == cond_sigmas
499
+ encoder_hidden_states.dtype,
500
+ bsz,
501
+ unet
502
+ )
503
+ added_time_ids = added_time_ids.to(latents.device)
504
+
505
+
506
+
507
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
508
+ # check out section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
509
+ if args.conditioning_dropout_prob is not None:
510
+ random_p = torch.rand(
511
+ bsz, device=latents.device, generator=generator)
512
+ # Sample masks for the edit prompts. I'm not sure if prompts are used in this model; same with the text conditioning that comes next.
513
+
514
+ # encoder_hidden_states is derived from the image.
515
+
516
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
517
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
518
+ # Final text conditioning.
519
+ null_conditioning = torch.zeros_like(encoder_hidden_states)
520
+ encoder_hidden_states = torch.where(
521
+ prompt_mask, null_conditioning.unsqueeze(1), encoder_hidden_states.unsqueeze(1))
522
+ # Sample masks for the original images.
523
+ image_mask_dtype = conditional_latents.dtype
524
+ image_mask = 1 - (
525
+ (random_p >= args.conditioning_dropout_prob).to(
526
+ image_mask_dtype)
527
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
528
+ )
529
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
530
+ # Final image conditioning.
531
+ conditional_latents = image_mask * conditional_latents #this basically 0s out some of the image latents
532
+
533
+ # Concatenate the `conditional_latents` with the `noisy_latents`.
534
+ # conditional_latents = conditional_latents.unsqueeze(
535
+ # 1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
536
+ if conditioning == "ablate_single_frame":
537
+ #put input frame at first frame
538
+ conditional_latents = conditional_latents[:, 0:1].repeat(1, args.num_frames, 1, 1, 1)
539
+ elif conditioning in ["ablate_position", "ablate_time"]:
540
+
541
+ conditional_latents = conditional_latents[:, random_frames[0]:random_frames[0]+1].repeat(1,args.num_frames, 1, 1, 1)
542
+ else:
543
+ mask = torch.zeros_like(conditional_latents)
544
+ #choose a random frame to allow for the model to learn to focus on different frames (set mask to 1 for that frame)
545
+ mask[:, random_frames] = 1
546
+ conditional_latents = conditional_latents * mask
547
+
548
+
549
+ inp_noisy_latents = torch.cat(
550
+ [inp_noisy_latents, conditional_latents], dim=2)
551
+
552
+ # check https://arxiv.org/abs/2206.00364 (the EDM framework) for more details.
553
+ target = latents
554
+ model_pred = unet(
555
+ inp_noisy_latents, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample
556
+
557
+ # Denoise the latents
558
+ c_out = -sigmas / ((sigmas**2 + 1)**0.5)
559
+ c_skip = 1 / (sigmas**2 + 1)
560
+ denoised_latents = model_pred * c_out + c_skip * noisy_latents
561
+ weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
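+ # EDM denoiser parameterisation: D(x) = c_skip * x + c_out * F(x) with c_skip = 1 / (sigma^2 + 1),
+ # c_out = -sigma / sqrt(sigma^2 + 1), and loss weight lambda(sigma) = (sigma^2 + 1) / sigma^2.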
562
+
563
+ # MSE loss
564
+ loss = torch.mean(
565
+ (weighing.float() * (denoised_latents.float() -
566
+ target.float()) ** 2).reshape(target.shape[0], -1),
567
+ dim=1,
568
+ )
569
+ loss = loss.mean()
570
+
571
+ # Gather the losses across all processes for logging (if we use distributed training).
572
+ avg_loss = accelerator.gather(
573
+ loss.repeat(args.per_gpu_batch_size)).mean()
574
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
575
+
576
+ # Backpropagate
577
+ accelerator.backward(loss)
+ optimizer.step()
578
+ lr_scheduler.step()
579
+ optimizer.zero_grad()
580
+
581
+
582
+
583
+
584
+ # Checks if the accelerator has performed an optimization step behind the scenes
585
+ if accelerator.sync_gradients:
586
+
587
+ if args.use_ema:
588
+ ema_unet.step(unet.parameters())
589
+ progress_bar.update(1)
590
+ global_step += 1
591
+ accelerator.log({"train_loss": train_loss}, step=global_step)
592
+ train_loss = 0.0
593
+
594
+ if accelerator.is_main_process:
595
+
596
+ # save checkpoints!
597
+ if global_step % args.checkpointing_steps == 0:
598
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
599
+ if args.checkpoints_total_limit is not None:
600
+ checkpoints = os.listdir(args.output_dir)
601
+ checkpoints = [
602
+ d for d in checkpoints if d.startswith("checkpoint")]
603
+ checkpoints = sorted(
604
+ checkpoints, key=lambda x: int(x.split("-")[1]))
605
+
606
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
607
+ if len(checkpoints) >= args.checkpoints_total_limit:
608
+ num_to_remove = len(
609
+ checkpoints) - args.checkpoints_total_limit + 1
610
+ removing_checkpoints = checkpoints[0:num_to_remove]
611
+
612
+ logger.info(
613
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
614
+ )
615
+ logger.info(
616
+ f"removing checkpoints: {', '.join(removing_checkpoints)}")
617
+
618
+ for removing_checkpoint in removing_checkpoints:
619
+ removing_checkpoint = os.path.join(
620
+ args.output_dir, removing_checkpoint)
621
+ shutil.rmtree(removing_checkpoint)
622
+
623
+ save_path = os.path.join(
624
+ args.output_dir, f"checkpoint-{global_step}")
625
+ accelerator.save_state(save_path)
626
+ logger.info(f"Saved state to {save_path}")
627
+ # sample images!
628
+ if args.test or (global_step % args.validation_steps == 0) or (global_step == 1):
629
+ if args.use_ema:
630
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
631
+ ema_unet.store(unet.parameters())
632
+ ema_unet.copy_to(unet.parameters())
633
+
634
+ valid_net(args, val_dataset, val_dataloader, unet, image_encoder, vae, zero_latent, accelerator, global_step, weight_dtype)
635
+ if args.use_ema:
636
+ # Switch back to the original UNet parameters.
637
+ ema_unet.restore(unet.parameters())
638
+ if args.test:
639
+ break
640
+
641
+ torch.cuda.empty_cache()
642
+
643
+
644
+
645
+
646
+ logs = {"step_loss": loss.detach().item(
647
+ ), "lr": lr_scheduler.get_last_lr()[0]}
648
+ progress_bar.set_postfix(**logs)
649
+
650
+ if global_step >= args.max_train_steps:
651
+ break
652
+ if args.test:
653
+ break
654
+ # Create the pipeline using the trained modules and save it.
655
+ accelerator.wait_for_everyone()
656
+ if accelerator.is_main_process and not args.test:
657
+
658
+ pipeline = StableVideoDiffusionPipeline.from_pretrained(
659
+ args.pretrained_model_name_or_path,
660
+ image_encoder=accelerator.unwrap_model(image_encoder),
661
+ vae=accelerator.unwrap_model(vae),
662
+ unet=accelerator.unwrap_model(ema_unet) if args.use_ema else unet,
663
+ revision=args.revision,
664
+ )
665
+ pipeline.save_pretrained(args.output_dir)
666
+
667
+ if args.use_ema:
668
+ ema_unet.copy_to(unet.parameters())
669
+
670
+ if args.push_to_hub:
671
+ upload_folder(
672
+ repo_id=repo_id,
673
+ folder_path=args.output_dir,
674
+ commit_message="End of training",
675
+ ignore_patterns=["step_*", "epoch_*"],
676
+ )
677
+ accelerator.end_training()
678
+
679
+
680
+ if __name__ == "__main__":
681
+ main()
682
+
683
+
training/utils.py ADDED
@@ -0,0 +1,509 @@
1
+ import pickle
2
+ from torch.utils.data import Dataset
3
+ import cv2
4
+ import argparse
5
+ import glob
6
+ import random
7
+ import logging
8
+ import torch
9
+ import os
10
+ import numpy as np
11
+ import PIL
12
+ from PIL import Image, ImageDraw
13
+ from einops import rearrange
14
+ from urllib.parse import urlparse
15
+ from diffusers.utils import load_image
16
+ import math
17
+
18
+ # copy from https://github.com/crowsonkb/k-diffusion.git
19
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
20
+ """Draws samples from an lognormal distribution."""
21
+ u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
22
+ return torch.distributions.Normal(loc, scale).icdf(u).exp()
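+ # u is kept inside (1e-7, 1 - 1e-7) so the inverse CDF never returns +/- infinity before exponentiation.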
23
+
24
+ def encode_image(pixel_values, feature_extractor, image_encoder, weight_dtype, accelerator):
25
+ # pixel: [-1, 1]
26
+ pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
27
+ # We unnormalize it after resizing.
28
+ pixel_values = (pixel_values + 1.0) / 2.0
29
+
30
+ # Normalize the image for CLIP input
31
+ pixel_values = feature_extractor(
32
+ images=pixel_values,
33
+ do_normalize=True,
34
+ do_center_crop=False,
35
+ do_resize=False,
36
+ do_rescale=False,
37
+ return_tensors="pt",
38
+ ).pixel_values
39
+
40
+ pixel_values = pixel_values.to(
41
+ device=accelerator.device, dtype=weight_dtype)
42
+ image_embeddings = image_encoder(pixel_values).image_embeds
43
+ return image_embeddings
44
+
45
+ def get_add_time_ids(
46
+ fps,
47
+ motion_bucket_id,
48
+ noise_aug_strength,
49
+ dtype,
50
+ batch_size,
51
+ unet
52
+ ):
53
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
54
+
55
+ passed_add_embed_dim = unet.config.addition_time_embed_dim * \
56
+ len(add_time_ids)
57
+ expected_add_embed_dim = unet.add_embedding.linear_1.in_features
58
+
59
+ if expected_add_embed_dim != passed_add_embed_dim:
60
+ raise ValueError(
61
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
62
+ )
63
+
64
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
65
+ add_time_ids = add_time_ids.repeat(batch_size, 1)
66
+ return add_time_ids
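+ # These are SVD's three "added time ids" (fps, motion_bucket_id, noise_aug_strength). In this repo
+ # the motion_bucket_id slot is repurposed to carry the focal position when conditioning == "ablate_time".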
67
+
68
+ def find_scale(height, width):
69
+ """
70
+ Finds a scale factor such that the number of pixels is less than 500,000
71
+ and the dimensions are rounded down to the nearest multiple of 64.
72
+
73
+ Args:
74
+ height (int): The original height of the image.
75
+ width (int): The original width of the image.
76
+
77
+ Returns:
78
+ tuple: The scaled height and width as integers.
79
+ """
80
+ max_pixels = 500000
81
+
82
+ # Start with no scaling
83
+ scale = 1.0
84
+
85
+ while True:
86
+ # Calculate the scaled dimensions
87
+ scaled_height = math.floor((height * scale) / 64) * 64
88
+ scaled_width = math.floor((width * scale) / 64) * 64
89
+
90
+ # Check if the scaled dimensions meet the pixel constraint
91
+ if scaled_height * scaled_width <= max_pixels:
92
+ return scaled_height, scaled_width
93
+
94
+ # Reduce the scale slightly
95
+ scale -= 0.01
96
+
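+ # Illustrative find_scale example: a 576x1024 (height x width) input has 589,824 pixels, so the
+ # first 1% reduction already fits and the result should be roughly (512, 960) after rounding each
+ # dimension down to a multiple of 64.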
97
+ class OutsidePhotosDataset(Dataset):
98
+ def __init__(self, data_folder, width=1024, height=576, sample_frames=9):
99
+ self.data_folder = data_folder
100
+ self.scenes = sorted(glob.glob(os.path.join(data_folder, "*")))
101
+
102
+ # keep only image files (.JPG, .jpg, .jpeg, .png)
103
+ self.scenes = [scene for scene in self.scenes if scene.endswith((".JPG", ".jpg", ".jpeg", ".png"))]
104
+ # expand each scene into 9 entries of (scene_path, focal position idx 0-8)
105
+
106
+ self.scenes = [(scene, idx) for scene in self.scenes for idx in range(9)]
107
+
108
+
109
+ self.num_scenes = len(self.scenes)
110
+ self.width = width
111
+ self.height = height
112
+ self.sample_frames = sample_frames
113
+ self.icc_profiles = [None]*self.num_scenes
114
+
115
+ def __len__(self):
116
+ return self.num_scenes
117
+
118
+ def __getitem__(self, idx):
119
+ #get the scene and the index
120
+ #create an empty tensor to store the pixel values and place the scene in the tensor (load and resize the image)
121
+
122
+ scene, focal_stack_num = self.scenes[idx]
123
+
124
+ with Image.open(scene) as img:
125
+
126
+ self.icc_profiles[idx] = img.info.get("icc_profile")
127
+ icc_profile = img.info.get("icc_profile")
128
+ if icc_profile is None:
129
+ icc_profile = "none"
130
+ original_pixels = torch.from_numpy(np.array(img)).float().permute(2,0,1)
131
+ original_pixels = original_pixels / 255
132
+ width, height = img.size
133
+ scaled_width, scaled_height = find_scale(width, height)
134
+
135
+ img_resized = img.resize((scaled_width, scaled_height))
136
+ img_tensor = torch.from_numpy(np.array(img_resized)).float()
137
+ img_normalized = img_tensor / 127.5 - 1
138
+ img_normalized = img_normalized.permute(2, 0, 1)
139
+
140
+ pixels = torch.zeros((self.sample_frames, 3, scaled_height, scaled_width))
141
+ pixels[focal_stack_num] = img_normalized
142
+
143
+ return {"pixel_values": pixels, "idx": idx//9, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile}
144
+
145
+
146
+
147
+
148
+ class FocalStackDataset(Dataset):
149
+ def __init__(self, data_folder: str, splits_dir, split="train", num_samples=100000, width=640, height=896, sample_frames=9): #4.5
150
+ #800*600 - 480000
151
+ #896*672 - 602112
152
+ """
153
+ Args:
154
+ data_folder (str): Root folder containing the focal-stack captures.
156
+ splits_dir (str): Folder with the train/test split pickle files.
+ split (str): One of "train", "val" or "test".
+ sample_frames (int): Number of focal positions per stack (default 9).
156
+ """
157
+ self.num_samples = num_samples
158
+ self.sample_frames = sample_frames
159
+ # Define the path to the folder containing video frames
160
+ self.data_folder = data_folder
161
+ self.splits_dir = splits_dir
162
+
163
+ size = "midsize"
164
+ # Use glob to find matching folders
165
+ # List to store the desired paths
166
+ rig_directories = []
167
+
168
+ # Walk through the directory
169
+ for root, dirs, files in os.walk(data_folder):
170
+ # Check if the path matches "downscaled/undistorted/Rig*"
171
+ for directory in dirs:
172
+ if directory.startswith("RigCenter") and f"{size}/undistorted" in root.replace("\\", "/"):
173
+ rig_directory = os.path.join(root, directory)
174
+ #check that rig_directory contains all 9 images
175
+ if len(glob.glob(os.path.join(rig_directory, "*.jpg"))) == 9:
176
+ rig_directories.append(rig_directory)
177
+
178
+
179
+ self.scenes = sorted(rig_directories) #sort the files by name
180
+
181
+ if split == "train":
182
+ #shuffle the scenes
183
+ random.shuffle(self.scenes)
184
+ self.split = split
185
+
186
+ debug = False
187
+
188
+
189
+ if debug:
190
+ self.scenes = self.scenes[50:60]
191
+ elif split == "train":
192
+ pkl_file = os.path.join(self.splits_dir, "train_scenes.pkl")
193
+ #load the train scenes
194
+ with open(pkl_file, "rb") as f:
195
+ pkl_scenes = pickle.load(f)
196
+
197
+ #only get scenes that are found in pkl file
198
+ self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
199
+
200
+ elif split == "val":
201
+ pkl_file = os.path.join(self.splits_dir, "test_scenes.pkl") #use first 10 test scenes for val (just for visualization)
202
+
203
+ #load the test scenes
204
+ with open(pkl_file, "rb") as f:
205
+ pkl_scenes = pickle.load(f)
206
+
207
+ #only get scenes that are found in pkl file
208
+ self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
209
+ self.scenes = self.scenes[:10]
210
+ else:
211
+ pkl_file = os.path.join(self.splits_dir, "test_scenes.pkl")
212
+
213
+ #load the test scenes
214
+ with open(pkl_file, "rb") as f:
215
+ pkl_scenes = pickle.load(f)
216
+
217
+ #only get scenes that are found in pkl file
218
+ self.scenes = [scene for scene in self.scenes if scene.split('/')[-4] in pkl_scenes]
219
+
220
+
221
+
222
+ if split == "test":
223
+ self.scenes = [(scene, idx) for scene in self.scenes for idx in range(self.sample_frames)]
224
+
225
+ self.num_scenes = len(self.scenes)
226
+
227
+ max_trdata = 0
228
+ if max_trdata > 0:
229
+ self.scenes = self.scenes[:max_trdata]
230
+
231
+ self.data_store = {}
232
+
233
+ logging.info(f'Creating {split} dataset with {self.num_scenes} examples')
234
+
235
+ self.channels = 3
236
+ self.width = width
237
+ self.height = height
238
+
239
+
240
+ def __len__(self):
241
+ return self.num_scenes
242
+
243
+ def __getitem__(self, idx):
244
+ """
245
+ Args:
246
+ idx (int): Index of the sample to return.
247
+
248
+ Returns:
249
+ dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512).
250
+ """
251
+ # Randomly select a folder (representing a video) from the base folder
252
+ if self.split == "test":
253
+ chosen_folder, focal_stack_num = self.scenes[idx]
254
+ else:
255
+ chosen_folder = self.scenes[idx]
256
+ frames = os.listdir(chosen_folder)
257
+ #get only frames that are jpg
258
+ frames = [frame for frame in frames if frame.endswith(".jpg")]
259
+ # Sort the frames by name
260
+ frames.sort()
261
+
262
+ # Take the first sample_frames frames of the (sorted) stack
263
+ selected_frames = frames[:self.sample_frames]
264
+ # Initialize a tensor to store the pixel values
265
+ pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))
266
+
267
+ original_pixel_values = torch.empty((self.sample_frames, self.channels, 896, 640))
268
+
269
+ # Load and process each frame
270
+ for i, frame_name in enumerate(selected_frames):
271
+ frame_path = os.path.join(chosen_folder, frame_name)
272
+ with Image.open(frame_path) as img:
273
+
274
+
275
+ # Resize the image and convert it to a tensor
276
+ img_resized = img.resize((self.width, self.height))
277
+ img_tensor = torch.from_numpy(np.array(img_resized)).float()
278
+ original_img_tensor = torch.from_numpy(np.array(img)).float()
279
+
280
+ # Normalize the image by scaling pixel values to [-1, 1]
281
+ img_normalized = img_tensor / 127.5 - 1
282
+ original_img_normalized = original_img_tensor / 127.5 - 1
283
+
284
+ # Rearrange channels if necessary
285
+ if self.channels == 3:
286
+ img_normalized = img_normalized.permute(
287
+ 2, 0, 1) # For RGB images
288
+ original_img_normalized = original_img_normalized.permute(2, 0, 1)
289
+
290
+ pixel_values[i] = img_normalized
291
+ original_pixel_values[i] = original_img_normalized
292
+
293
+ if self.sample_frames == 10: # special case for 10 frames: duplicate the 9th frame (sometimes reduces color artifacts)
294
+ pixel_values[9] = pixel_values[8]
295
+ original_pixel_values[9] = original_pixel_values[8]
296
+ out_dict = {'pixel_values': pixel_values, "idx": idx, "original_pixel_values": original_pixel_values}
297
+ if self.split == "test":
298
+ out_dict["focal_stack_num"] = focal_stack_num
299
+ out_dict["idx"] = idx//9
300
+ return out_dict
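+ # In "test" mode each scene appears once per focal position, so idx is divided by 9 (the focal
+ # stack size) to recover the scene index while focal_stack_num picks the conditioning frame.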
301
+
302
+ # resizing utils
303
+ # TODO: clean up later
304
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
305
+ h, w = input.shape[-2:]
306
+ factors = (h / size[0], w / size[1])
307
+
308
+ # First, we have to determine sigma
309
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
310
+ sigmas = (
311
+ max((factors[0] - 1.0) / 2.0, 0.001),
312
+ max((factors[1] - 1.0) / 2.0, 0.001),
313
+ )
314
+
315
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
316
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
317
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
318
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
319
+
320
+ # Make sure it is odd
321
+ if (ks[0] % 2) == 0:
322
+ ks = ks[0] + 1, ks[1]
323
+
324
+ if (ks[1] % 2) == 0:
325
+ ks = ks[0], ks[1] + 1
326
+
327
+ input = _gaussian_blur2d(input, ks, sigmas)
328
+
329
+ output = torch.nn.functional.interpolate(
330
+ input, size=size, mode=interpolation, align_corners=align_corners)
331
+ return output
332
+
333
+
334
+ def _compute_padding(kernel_size):
335
+ """Compute padding tuple."""
336
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
337
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
338
+ if len(kernel_size) < 2:
339
+ raise AssertionError(kernel_size)
340
+ computed = [k - 1 for k in kernel_size]
341
+
342
+ # for even kernels we need to do asymmetric padding :(
343
+ out_padding = 2 * len(kernel_size) * [0]
344
+
345
+ for i in range(len(kernel_size)):
346
+ computed_tmp = computed[-(i + 1)]
347
+
348
+ pad_front = computed_tmp // 2
349
+ pad_rear = computed_tmp - pad_front
350
+
351
+ out_padding[2 * i + 0] = pad_front
352
+ out_padding[2 * i + 1] = pad_rear
353
+
354
+ return out_padding
355
+
356
+
357
+ def _filter2d(input, kernel):
358
+ # prepare kernel
359
+ b, c, h, w = input.shape
360
+ tmp_kernel = kernel[:, None, ...].to(
361
+ device=input.device, dtype=input.dtype)
362
+
363
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
364
+
365
+ height, width = tmp_kernel.shape[-2:]
366
+
367
+ padding_shape: list[int] = _compute_padding([height, width])
368
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
369
+
370
+ # kernel and input tensor reshape to align element-wise or batch-wise params
371
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
372
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
373
+
374
+ # convolve the tensor with the kernel.
375
+ output = torch.nn.functional.conv2d(
376
+ input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
377
+
378
+ out = output.view(b, c, h, w)
379
+ return out
380
+
381
+
382
+ def _gaussian(window_size: int, sigma):
383
+ if isinstance(sigma, float):
384
+ sigma = torch.tensor([[sigma]])
385
+
386
+ batch_size = sigma.shape[0]
387
+
388
+ x = (torch.arange(window_size, device=sigma.device,
389
+ dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
390
+
391
+ if window_size % 2 == 0:
392
+ x = x + 0.5
393
+
394
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
395
+
396
+ return gauss / gauss.sum(-1, keepdim=True)
397
+
398
+
399
+ def _gaussian_blur2d(input, kernel_size, sigma):
400
+ if isinstance(sigma, tuple):
401
+ sigma = torch.tensor([sigma], dtype=input.dtype)
402
+ else:
403
+ sigma = sigma.to(dtype=input.dtype)
404
+
405
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
406
+ bs = sigma.shape[0]
407
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
408
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
409
+ out_x = _filter2d(input, kernel_x[..., None, :])
410
+ out = _filter2d(out_x, kernel_y[..., None])
411
+
412
+ return out
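+ # Separable Gaussian blur: filter with a 1-D kernel along the width, then with another along the height.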
413
+
414
+
415
+ def export_to_video(video_frames, output_video_path, fps):
416
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
417
+ h, w, _ = video_frames[0].shape
418
+ video_writer = cv2.VideoWriter(
419
+ output_video_path, fourcc, fps=fps, frameSize=(w, h))
420
+ for i in range(len(video_frames)):
421
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
422
+ video_writer.write(img)
+ video_writer.release() # finalize the file so the video is playable
423
+
424
+
425
+ def export_to_gif(frames, output_gif_path, fps):
426
+ """
427
+ Export a list of frames to a GIF.
428
+
429
+ Args:
430
+ - frames (list): List of frames (as numpy arrays or PIL Image objects).
431
+ - output_gif_path (str): Path to save the output GIF.
432
+ - fps (int): Frames per second used to derive each frame's duration.
433
+
434
+ """
435
+ # Convert numpy arrays to PIL Images if needed
436
+ pil_frames = [Image.fromarray(frame) if isinstance(
437
+ frame, np.ndarray) else frame for frame in frames]
438
+
439
+ pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
440
+ format='GIF',
441
+ append_images=pil_frames[1:],
442
+ save_all=True,
443
+ duration=int(1000 / fps),
444
+ loop=0)
445
+
446
+
447
+ def tensor_to_vae_latent(t, vae, otype="sample"):
448
+ video_length = t.shape[1]
449
+
450
+ t = rearrange(t, "b f c h w -> (b f) c h w")
451
+ if otype == "sample":
452
+ latents = vae.encode(t).latent_dist.sample()
453
+ else:
454
+ latents = vae.encode(t).latent_dist.mode()
455
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
456
+ latents = latents * vae.config.scaling_factor
457
+
458
+ return latents
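+ # Frames are folded into the batch dimension, encoded independently, unfolded back to (b, f, c, h, w)
+ # and scaled by vae.config.scaling_factor; otype="sample" draws from the latent posterior, any other
+ # value takes its mode.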
459
+
460
+ import yaml
461
+ def parse_config(config_path="config.yaml"):
462
+ with open(config_path, "r") as f:
463
+ config = yaml.safe_load(f)
464
+
465
+ # handle distributed training rank
466
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
467
+ if env_local_rank != -1 and env_local_rank != config.get("local_rank", -1):
468
+ config["local_rank"] = env_local_rank
469
+
470
+ # default fallback: non_ema_revision = revision
471
+ if config.get("non_ema_revision") is None:
472
+ config["non_ema_revision"] = config.get("revision")
473
+
474
+ return config
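+ # A minimal illustrative config sketch (key names are taken from how `args` is consumed in the
+ # training and validation scripts; the values are placeholders, not the project's real settings):
+ #
+ # pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid
+ # output_dir: ./outputs/run1
+ # data_folder: /path/to/data
+ # splits_dir: /path/to/splits
+ # num_frames: 9
+ # per_gpu_batch_size: 1
+ # num_workers: 4
+ # learning_rate: 1e-5
+ # conditioning: random
+ # mixed_precision: fp16
+ # use_ema: true
+ # seed: 42
+ # report_to: wandb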
475
+
476
+ def parse_args():
477
+ parser = argparse.ArgumentParser(description="SVD Training Script")
478
+ parser.add_argument(
479
+ "--config",
480
+ type=str,
481
+ default="svd/scripts/training/configs/stage1_base.yaml",
482
+ help="Path to the config file.",
483
+ )
484
+
485
+ args = parser.parse_args()
486
+
487
+
488
+ # load YAML and merge into args
489
+ config = parse_config(args.config)
490
+ # combine yaml + command line args (command line has priority)
491
+ for k, v in vars(args).items():
492
+ if v is not None:
493
+ config[k] = v
494
+
495
+ # convert dict to argparse.Namespace for downstream compatibility
496
+ args = argparse.Namespace(**config)
497
+
498
+ print("OUTPUT DIR: ", args.output_dir)
499
+ return args
500
+
501
+
502
+ def download_image(url):
503
+ original_image = (
504
+ lambda image_url_or_path: load_image(image_url_or_path)
505
+ if urlparse(image_url_or_path).scheme
506
+ else PIL.Image.open(image_url_or_path).convert("RGB")
507
+ )(url)
508
+ return original_image
509
+
training/validation.py ADDED
@@ -0,0 +1,145 @@
1
+ from torchmetrics import MetricCollection
2
+ from svd_pipeline import StableVideoDiffusionPipeline
3
+ from accelerate.logging import get_logger
4
+ import os
5
+ from utils import load_image
6
+ import torch
7
+ import numpy as np
8
+ import videoio
9
+ import torchmetrics.image
10
+ import matplotlib.image
11
+ from PIL import Image
12
+
13
+ logger = get_logger(__name__, log_level="INFO")
14
+
15
+
16
+ def valid_net(args, val_dataset, val_dataloader, unet, image_encoder, vae, zero, accelerator, global_step, weight_dtype):
17
+ logger.info(
18
+ f"Running validation... \n Generating {args.num_validation_images} videos."
19
+ )
20
+
21
+ # The models need unwrapping for compatibility in distributed training mode.
22
+
23
+ pipeline = StableVideoDiffusionPipeline.from_pretrained(
24
+ args.pretrained_model_name_or_path,
25
+ unet=unet,
26
+ image_encoder=image_encoder,
27
+ vae=vae,
28
+ revision=args.revision,
29
+ torch_dtype=weight_dtype,
30
+ )
31
+
32
+ pipeline.set_progress_bar_config(disable=True)
33
+
34
+ # run inference
35
+ val_save_dir = os.path.join(
36
+ args.output_dir, "validation_images")
37
+
38
+ print("Validation images will be saved to ", val_save_dir)
39
+
40
+ os.makedirs(val_save_dir, exist_ok=True)
41
+
42
+
43
+ num_frames = args.num_frames
44
+ unet.eval()
45
+ with torch.autocast(
46
+ str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
47
+ ):
48
+ for batch in val_dataloader:
49
+ # run under torch.no_grad() so no gradients are kept while freeing cached GPU memory
50
+ with torch.no_grad():
51
+ torch.cuda.empty_cache()
52
+
53
+ pixel_values = batch["pixel_values"].to(accelerator.device)
54
+ original_pixel_values = batch['original_pixel_values'].to(accelerator.device)
55
+ idx = batch["idx"].to(accelerator.device)
56
+ if "focal_stack_num" in batch:
57
+ focal_stack_num = batch["focal_stack_num"][0].item()
58
+ else:
59
+ focal_stack_num = None
60
+
61
+ svd_output, gt_frames = pipeline(
62
+ pixel_values,
63
+ height=pixel_values.shape[3],
64
+ width=pixel_values.shape[4],
65
+ num_frames=args.num_frames,
66
+ decode_chunk_size=8,
67
+ motion_bucket_id=0 if args.conditioning != "ablate_time" else focal_stack_num,
68
+ min_guidance_scale=1.5,
69
+ max_guidance_scale=1.5,
70
+ reconstruction_guidance_scale=args.reconstruction_guidance,
71
+ fps=7,
72
+ noise_aug_strength=0,
73
+ accelerator=accelerator,
74
+ weight_dtype=weight_dtype,
75
+ conditioning = args.conditioning,
76
+ focal_stack_num = focal_stack_num,
77
+ zero=zero
78
+ # generator=generator,
79
+ )
80
+ video_frames = svd_output.frames[0]
81
+ gt_frames = gt_frames[0]
82
+
83
+
84
+ with torch.no_grad():
85
+
86
+ if args.num_frames == 10:
87
+ #remove a frame at end from video_frames and gt_frames
88
+ video_frames = video_frames[:, :-1]
89
+ gt_frames = gt_frames[:, :-1]
90
+ original_pixel_values = original_pixel_values[:, :-1]
91
+
92
+ if len(original_pixel_values.shape) == 5:
93
+ pixel_values = original_pixel_values[0] #assuming batch size is 1
94
+ else:
95
+ pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
96
+ pixel_values_normalized = pixel_values*0.5 + 0.5
97
+ pixel_values_normalized = torch.clamp(pixel_values_normalized,0,1)
98
+
99
+
100
+
101
+
102
+ video_frames_normalized = video_frames*0.5 + 0.5
103
+ video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
104
+ video_frames_normalized = video_frames_normalized.permute(1,0,2,3)
105
+
106
+
107
+ gt_frames = torch.clamp(gt_frames,0,1)
108
+ gt_frames = gt_frames.permute(1,0,2,3)
109
+
110
+ #RESIZE images
111
+ video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
112
+ gt_frames = torch.nn.functional.interpolate(gt_frames, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
113
+ pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
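+ # All three tensors are resized to the same even-valued height and width ((x//2)*2 rounding),
+ # presumably so the video encoder accepts the frames and the outputs stay directly comparable.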
114
+
115
+ os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
116
+ videoio.videosave(os.path.join(
117
+ val_save_dir,
118
+ f"position_{focal_stack_num}/videos/step_{global_step}_val_img_{idx[0].item()}.mp4",
119
+ ), video_frames_normalized.permute(0,2,3,1).cpu().numpy(), fps=5)
120
+
121
+ if args.test:
122
+ #save images
123
+ os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
124
+ if not args.photos:
125
+ for i in range(num_frames):
126
+ matplotlib.image.imsave(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/img_{idx[0].item()}_frame_{i}.png"), video_frames_normalized[i].permute(1,2,0).cpu().numpy())
127
+ else:
128
+ for i in range(num_frames):
129
+ #use Pillow to save images
130
+ img = Image.fromarray((video_frames_normalized[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
131
+ #use index to assign icc profile to img
132
+ if batch['icc_profile'][0] != "none":
133
+ img.info['icc_profile'] = batch['icc_profile'][0]
134
+ img.save(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/img_{idx[0].item()}_frame_{i}.png"))
135
+ del video_frames
136
+
137
+ accelerator.wait_for_everyone()
138
+
139
+ # run under torch.no_grad() so no gradients are kept while freeing cached GPU memory
140
+ with torch.no_grad():
141
+ torch.cuda.empty_cache()
142
+
143
+ del pipeline
144
+
145
+ accelerator.wait_for_everyone() # important: every process must leave validation at the same time