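"""Gradio demo for Ditto instruction-based video editing.

Fetches the required WanVideo/Ditto weights from Hugging Face, boots a headless
ComfyUI instance, and serves the editing pipeline through a Gradio interface.
"""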
import os
import cv2
import numpy as np
import random
import sys
import subprocess
from typing import Sequence, Mapping, Any, Union
import torch
from tqdm import tqdm
import argparse
import json
import logging
import shutil
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
import time
import traceback
from utils import get_path_after_pexel
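# Redirect Gradio's temp directory to a local folder so uploaded videos and
# rendered results are stored inside the workspace.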
LOCAL_GRADIO_TMP = os.path.abspath("./gradio_tmp")
os.makedirs(LOCAL_GRADIO_TMP, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = LOCAL_GRADIO_TMP
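# Hugging Face repos and the specific weight files needed from each.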
HF_REPOS = {
"QingyanBai/Ditto_models": ["models_comfy/ditto_global_comfy.safetensors"],
"Kijai/WanVideo_comfy": [
"Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
"Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"Wan2_1_VAE_bf16.safetensors",
"umt5-xxl-enc-bf16.safetensors",
],
}
MODELS_ROOT = os.path.abspath(os.path.join(os.getcwd(), "models"))
PATHS = {
"diffusion_model": os.path.join(MODELS_ROOT, "diffusion_models"),
"vae_wan": os.path.join(MODELS_ROOT, "vae", "wan"),
"loras": os.path.join(MODELS_ROOT, "loras"),
"text_encoders": os.path.join(MODELS_ROOT, "text_encoders"),
}
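# (filename, PATHS key) pairs that must exist locally before the pipeline can run.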
REQUIRED_FILES = [
("Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", "diffusion_model"),
("ditto_global_comfy.safetensors", "diffusion_model"),
("Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "loras"),
("Wan2_1_VAE_bf16.safetensors", "vae_wan"),
("umt5-xxl-enc-bf16.safetensors", "text_encoders"),
]
def ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
def ensure_models() -> None:
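    """Download any missing weight files from Hugging Face, marking each with a .READY flag."""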
for filename, key in REQUIRED_FILES:
target_dir = PATHS[key]
ensure_dir(target_dir)
target_path = os.path.join(target_dir, filename)
ready_flag = os.path.join(target_dir, f"{filename}.READY")
if os.path.exists(target_path) and os.path.getsize(target_path) > 0:
open(ready_flag, "a").close()
continue
repo_id = None
repo_file_path = None
for repo, files in HF_REPOS.items():
for file_path in files:
if filename in file_path:
repo_id = repo
repo_file_path = file_path
break
if repo_id:
break
if repo_id is None:
raise RuntimeError(f"Could not find repository for file: {filename}")
print(f"Downloading {filename} from {repo_id} to {target_dir} ...")
snapshot_download(
repo_id=repo_id,
local_dir=target_dir,
local_dir_use_symlinks=False,
allow_patterns=[repo_file_path],
token=os.getenv("HF_TOKEN", None),
)
if not os.path.exists(target_path):
found = []
for root, _, files in os.walk(target_dir):
for f in files:
if f == filename:
found.append(os.path.join(root, f))
if found:
src = found[0]
if src != target_path:
shutil.copy2(src, target_path)
if not os.path.exists(target_path):
raise RuntimeError(f"Failed to download required file: {filename}")
open(ready_flag, "w").close()
print(f"Downloaded and ready: {target_path}")
ensure_models()
def ensure_t5_tokenizer() -> None:
"""
Ensure the local T5 tokenizer folder exists and contains valid files.
If missing or corrupted, download from 'google/umt5-xxl' and save locally
to the exact path expected by the WanVideo wrapper nodes.
"""
try:
script_directory = os.path.dirname(os.path.abspath(__file__))
tokenizer_dir = os.path.join(
script_directory,
"custom_nodes",
"ComfyUI_WanVideoWrapper",
"configs",
"T5_tokenizer",
)
os.makedirs(tokenizer_dir, exist_ok=True)
required_files = [
"tokenizer.json",
"tokenizer_config.json",
"spiece.model",
"special_tokens_map.json",
]
def is_valid(path: str) -> bool:
return os.path.exists(path) and os.path.getsize(path) > 0
all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files)
if all_ok:
print(f"T5 tokenizer ready at: {tokenizer_dir}")
return
print(f"Preparing T5 tokenizer at: {tokenizer_dir} ...")
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"google/umt5-xxl",
use_fast=True,
trust_remote_code=False,
)
tok.save_pretrained(tokenizer_dir)
# Re-check
all_ok = all(is_valid(os.path.join(tokenizer_dir, f)) for f in required_files)
if not all_ok:
raise RuntimeError("Tokenizer files not fully prepared after save_pretrained")
print("T5 tokenizer prepared successfully.")
except Exception as e:
print(f"Failed to prepare T5 tokenizer: {e}\n{traceback.format_exc()}")
raise
ensure_t5_tokenizer()
def setup_global_logging_filter():
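    """Install a root-logger filter that drops the WanVideo wrapper's verbose memory-usage messages."""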
class MemoryLogFilter(logging.Filter):
def filter(self, record):
msg = record.getMessage()
keywords = [
"Allocated memory:",
"Max allocated memory:",
"Max reserved memory:",
"memory=",
"max_memory=",
"max_reserved=",
"Block swap memory summary",
"Transformer blocks on",
"Total memory used by",
"Non-blocking memory transfer"
]
return not any(kw in msg for kw in keywords)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
force=True
)
logging.getLogger().handlers[0].addFilter(MemoryLogFilter())
setup_global_logging_filter()
def tensor_to_video(video_tensor, output_path, fps=20, crf=20):
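    """Encode a (frames, height, width, 3) RGB tensor to an H.264 MP4 by piping raw frames to ffmpeg."""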
frames = video_tensor.detach().cpu().numpy()
if frames.dtype != np.uint8:
if frames.max() <= 1.0:
frames = (frames * 255).astype(np.uint8)
else:
frames = frames.astype(np.uint8)
num_frames, height, width, _ = frames.shape
command = [
'ffmpeg',
'-y',
'-f', 'rawvideo',
'-vcodec', 'rawvideo',
'-pix_fmt', 'rgb24',
'-s', f'{width}x{height}',
'-r', str(fps),
'-i', '-',
'-c:v', 'libx264',
'-pix_fmt', 'yuv420p',
'-crf', str(crf),
'-preset', 'medium',
'-r', str(fps),
'-an',
output_path
]
    # Stream raw RGB frames to ffmpeg's stdin; capture stderr for diagnostics.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
        for frame in frames:
            proc.stdin.write(frame.tobytes())
        proc.stdin.close()
        stderr_output = proc.stderr.read() if proc.stderr is not None else b""
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg encoding failed: {stderr_output.decode(errors='ignore')}")
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
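    """Index into a ComfyUI node output; some nodes wrap their outputs in a {"result": ...} mapping."""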
try:
return obj[index]
except KeyError:
return obj["result"][index]
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
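    """Walk up from `path` (default: CWD) looking for an entry named `name`; return its path, or None if not found."""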
if path is None:
path = os.getcwd()
if name in os.listdir(path):
path_name = os.path.join(path, name)
print(f"{name} found: {path_name}")
return path_name
parent_directory = os.path.dirname(path)
if parent_directory == path:
return None
return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
comfyui_path = find_path("ComfyUI")
if comfyui_path is not None and os.path.isdir(comfyui_path):
if comfyui_path not in sys.path:
sys.path.append(comfyui_path)
print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
try:
from main import load_extra_path_config
except ImportError:
print(
"Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
)
from utils.extra_config import load_extra_path_config
extra_model_paths = find_path("extra_model_paths.yaml")
if extra_model_paths is not None:
load_extra_path_config(extra_model_paths)
else:
print("Could not find the extra_model_paths config file.")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
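    """Start a headless ComfyUI PromptServer so custom nodes can register, then load them (once per process)."""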
import asyncio
import execution
from nodes import init_extra_nodes
import server
if getattr(import_custom_nodes, "_initialized", False):
return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
init_extra_nodes()
import_custom_nodes._initialized = True
from nodes import NODE_CLASS_MAPPINGS
print(f"Loading custom nodes and models...")
import_custom_nodes()
# GPU allocation is handled by @spaces.GPU on gradio_infer, which is this function's only caller.
def run_pipeline(vpath, prompt, width, height, fps, frame_count, outdir):
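    """Run the Ditto/WanVideo editing pipeline on one video and return the path of the saved result."""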
try:
import gc
# Clean memory before starting
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
os.makedirs(outdir, exist_ok=True)
with torch.inference_mode():
from custom_nodes.ComfyUI_WanVideoWrapper import nodes as wan_nodes
vhs_loadvideo = NODE_CLASS_MAPPINGS["VHS_LoadVideo"]()
# Set model and settings.
wanvideovacemodelselect = wan_nodes.WanVideoVACEModelSelect()
wanvideovacemodelselect_89 = wanvideovacemodelselect.getvacepath(
vace_model="ditto_global_comfy.safetensors"
)
wanvideoslg = wan_nodes.WanVideoSLG()
wanvideoslg_113 = wanvideoslg.process(
blocks="2",
                start_percent=0.2,
                end_percent=0.7,
)
wanvideovaeloader = wan_nodes.WanVideoVAELoader()
wanvideovaeloader_133 = wanvideovaeloader.loadmodel(
model_name="wan/Wan2_1_VAE_bf16.safetensors", precision="bf16"
)
loadwanvideot5textencoder = wan_nodes.LoadWanVideoT5TextEncoder()
loadwanvideot5textencoder_134 = loadwanvideot5textencoder.loadmodel(
model_name="umt5-xxl-enc-bf16.safetensors",
precision="bf16",
load_device="offload_device",
quantization="disabled",
)
wanvideoblockswap = wan_nodes.WanVideoBlockSwap()
wanvideoblockswap_137 = wanvideoblockswap.setargs(
blocks_to_swap=20,
offload_img_emb=False,
offload_txt_emb=False,
use_non_blocking=True,
vace_blocks_to_swap=0,
)
wanvideoloraselect = wan_nodes.WanVideoLoraSelect()
wanvideoloraselect_380 = wanvideoloraselect.getlorapath(
lora="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
strength=1.0,
low_mem_load=False,
)
wanvideomodelloader = wan_nodes.WanVideoModelLoader()
imageresizekjv2 = NODE_CLASS_MAPPINGS["ImageResizeKJv2"]()
wanvideovaceencode = wan_nodes.WanVideoVACEEncode()
wanvideotextencode = wan_nodes.WanVideoTextEncode()
wanvideosampler = wan_nodes.WanVideoSampler()
wanvideodecode = wan_nodes.WanVideoDecode()
wanvideomodelloader_142 = wanvideomodelloader.loadmodel(
model="Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
base_precision="fp16",
quantization="disabled",
load_device="offload_device",
attention_mode="sdpa",
block_swap_args=get_value_at_index(wanvideoblockswap_137, 0),
lora=get_value_at_index(wanvideoloraselect_380, 0),
vace_model=get_value_at_index(wanvideovacemodelselect_89, 0),
)
fname = os.path.basename(vpath)
fname_clean = os.path.splitext(fname)[0]
vhs_loadvideo_70 = vhs_loadvideo.load_video(
video=vpath,
force_rate=20,
custom_width=width,
custom_height=height,
frame_load_cap=frame_count,
skip_first_frames=1,
select_every_nth=1,
format="AnimateDiff",
unique_id=16696422174153060213,
)
imageresizekjv2_205 = imageresizekjv2.resize(
width=width,
height=height,
upscale_method="area",
keep_proportion="resize",
pad_color="0, 0, 0",
crop_position="center",
divisible_by=8,
device="cpu",
image=get_value_at_index(vhs_loadvideo_70, 0),
)
wanvideovaceencode_29 = wanvideovaceencode.process(
width=width,
height=height,
num_frames=frame_count,
                strength=0.975,
vace_start_percent=0,
vace_end_percent=1,
tiled_vae=False,
vae=get_value_at_index(wanvideovaeloader_133, 0),
input_frames=get_value_at_index(imageresizekjv2_205, 0),
)
wanvideotextencode_148 = wanvideotextencode.process(
positive_prompt=prompt,
negative_prompt="flickering artifact, jpg artifacts, compression, distortion, morphing, low-res, fake, oversaturated, overexposed, over bright, strange behavior, distorted limbs, unnatural motion, unrealistic anatomy, glitch, extra limbs,",
force_offload=True,
t5=get_value_at_index(loadwanvideot5textencoder_134, 0),
model_to_offload=get_value_at_index(wanvideomodelloader_142, 0),
)
# Clean memory before sampling (most memory-intensive step)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
wanvideosampler_2 = wanvideosampler.process(
steps=4,
                cfg=1.2,
                shift=2.0,
                # randint's upper bound is inclusive; cap at 2**64 - 1 to stay within a uint64 seed.
                seed=random.randint(1, 2**64 - 1),
force_offload=True,
scheduler="unipc",
riflex_freq_index=0,
denoise_strength=1,
batched_cfg=False,
rope_function="comfy",
model=get_value_at_index(wanvideomodelloader_142, 0),
image_embeds=get_value_at_index(wanvideovaceencode_29, 0),
text_embeds=get_value_at_index(wanvideotextencode_148, 0),
slg_args=get_value_at_index(wanvideoslg_113, 0),
)
res = wanvideodecode.decode(
enable_vae_tiling=False,
tile_x=272,
tile_y=272,
tile_stride_x=144,
tile_stride_y=128,
vae=get_value_at_index(wanvideovaeloader_133, 0),
samples=get_value_at_index(wanvideosampler_2, 0),
)
save_path = os.path.join(outdir, f'{fname_clean}_edit.mp4')
tensor_to_video(res[0], save_path, fps=fps)
# Clean up memory after generation
del res
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"Done. Saved to: {save_path}")
return save_path
except Exception as e:
err = f"Error: {e}\n{traceback.format_exc()}"
print(err)
# Clean memory on error too
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise
@spaces.GPU()
def gradio_infer(vfile, prompt, width, height, fps, frame_count, progress=gr.Progress(track_tqdm=True)):
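    """Gradio callback: resolve the uploaded video to a local path and run the editing pipeline."""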
    if vfile is None:
        raise gr.Error("Please upload a video first!")
vpath = vfile if isinstance(vfile, str) else vfile.name
if not os.path.exists(vpath) and hasattr(vfile, "save"):
os.makedirs("uploads", exist_ok=True)
vpath = os.path.join("uploads", os.path.basename(vfile.name))
vfile.save(vpath)
outdir = "results"
os.makedirs(outdir, exist_ok=True)
save_path = run_pipeline(
vpath=vpath,
prompt=prompt,
width=int(width),
height=int(height),
fps=int(fps),
frame_count=int(frame_count),
outdir=outdir,
)
return save_path
def build_interface():
with gr.Blocks(title="Ditto") as demo:
gr.Markdown(
"""# Ditto: Scaling Instruction-Based Video Editing with a High-Quality Synthetic Dataset
<div style="font-size: 1.8rem; line-height: 1.6; margin-bottom: 1rem;">
<a href="https://arxiv.org/abs/2510.15742" target="_blank">📄 Paper</a>
&nbsp; | &nbsp;
<a href="https://ezioby.github.io/Ditto_page/" target="_blank">🌐 Project Page</a>
&nbsp; | &nbsp;
<a href="https://github.com/EzioBy/Ditto/" target="_blank"> 💻 Github Code </a>
&nbsp; | &nbsp;
<a href="https://huggingface.co/QingyanBai/Ditto_models/tree/main" target="_blank">📦 Model Weights</a>
&nbsp; | &nbsp;
<a href="https://huggingface.co/datasets/QingyanBai/Ditto-1M" target="_blank">📊 Dataset</a>
</div>
<b>Note 1:</b> This demo is backed by ComfyUI. It runs fast, but because it uses quantized and distilled models, some quality degradation is possible.
<b>Note 2:</b> Memory on this Space is limited, so please use test cases with lower resolutions and frame counts; otherwise you may hit an out-of-memory error (re-running sometimes helps).
If you like this project, please consider <a href="https://github.com/EzioBy/Ditto/" target="_blank">starring the repo</a> to motivate us. Thank you!
"""
)
with gr.Column():
with gr.Row():
vfile = gr.Video(label="Input Video", value=os.path.join("input", "dasha.mp4"),
sources="upload", interactive=True)
out_video = gr.Video(label="Result")
prompt = gr.Textbox(label="Editing Instruction", value="Make it in the style of Japanese anime")
with gr.Row():
width = gr.Number(label="Width", value=576, precision=0)
height = gr.Number(label="Height", value=324, precision=0)
fps = gr.Number(label="FPS", value=20, precision=0)
frame_count = gr.Number(label="Frame Count", value=49, precision=0)
run_btn = gr.Button("Run", variant="primary")
run_btn.click(
fn=gradio_infer,
inputs=[vfile, prompt, width, height, fps, frame_count],
outputs=[out_video]
)
examples = [
[
os.path.join("input", "dasha.mp4"),
"Add some fire and flame to the background",
576, 324, 20, 49
],
[
os.path.join("input", "dasha.mp4"),
"Add some snow and flakes to the background",
576, 324, 20, 49
],
[
os.path.join("input", "dasha.mp4"),
"Make it in the style of pencil sketch",
576, 324, 20, 49
],
]
gr.Examples(
examples=examples,
inputs=[vfile, prompt, width, height, fps, frame_count],
label="Examples"
)
return demo
if __name__ == "__main__":
demo = build_interface()
demo.launch()