Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import cv2 | |
| import numpy as np | |
| import random | |
| import sys | |
| import subprocess | |
| from typing import Sequence, Mapping, Any, Union | |
| import torch | |
| from tqdm import tqdm | |
| import argparse | |
| import json | |
| import logging | |
| import shutil | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import snapshot_download | |
| import time | |
| import traceback | |
| from utils import get_path_after_pexel | |
# Keep Gradio's temporary/upload files inside the app folder (./gradio_tmp).
# NOTE(review): gradio is imported before this env var is set; presumably it
# reads GRADIO_TEMP_DIR lazily (at Blocks/component creation) — confirm.
LOCAL_GRADIO_TMP = os.path.abspath(os.path.join(".", "gradio_tmp"))
os.environ["GRADIO_TEMP_DIR"] = LOCAL_GRADIO_TMP
os.makedirs(LOCAL_GRADIO_TMP, exist_ok=True)
# Hugging Face repositories to fetch weights from: repo id -> paths inside the repo.
HF_REPOS = {
    "QingyanBai/Ditto_models": ["models_comfy/ditto_global_comfy.safetensors"],
    "Kijai/WanVideo_comfy": [
        "Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
        "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
        "Wan2_1_VAE_bf16.safetensors",
        "umt5-xxl-enc-bf16.safetensors",
    ],
}

# Local model tree rooted at <cwd>/models (os.getcwd() is already absolute).
MODELS_ROOT = os.path.abspath(os.path.join(os.getcwd(), "models"))

# Logical destination name -> directory under MODELS_ROOT.
PATHS = dict(
    diffusion_model=os.path.join(MODELS_ROOT, "diffusion_models"),
    vae_wan=os.path.join(MODELS_ROOT, "vae", "wan"),
    loras=os.path.join(MODELS_ROOT, "loras"),
    text_encoders=os.path.join(MODELS_ROOT, "text_encoders"),
)

# (file name, PATHS key) for every weight file the pipeline needs, in download order.
REQUIRED_FILES = [
    ("Wan2_1-T2V-14B_fp8_e4m3fn.safetensors", "diffusion_model"),
    ("ditto_global_comfy.safetensors", "diffusion_model"),
    ("Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "loras"),
    ("Wan2_1_VAE_bf16.safetensors", "vae_wan"),
    ("umt5-xxl-enc-bf16.safetensors", "text_encoders"),
]
def ensure_dir(path: str) -> None:
    """Create directory *path*, including missing parents; existing directories are left alone."""
    os.makedirs(name=path, exist_ok=True)
def _repo_source_for(filename: str):
    """Return ``(repo_id, path_in_repo)`` of the HF_REPOS entry whose basename equals *filename*.

    Raises:
        RuntimeError: if no repository lists the file.
    """
    for repo, files in HF_REPOS.items():
        for file_path in files:
            # Exact basename comparison: a substring test could match the wrong
            # entry when one file name is contained in another.
            if os.path.basename(file_path) == filename:
                return repo, file_path
    raise RuntimeError(f"Could not find repository for file: {filename}")


def ensure_models() -> None:
    """Make sure every file in REQUIRED_FILES exists locally, downloading missing ones.

    Each ``(filename, key)`` pair is expected at ``PATHS[key]/filename``. An
    existing non-empty file is accepted as-is. Otherwise the file is fetched
    with ``snapshot_download`` from the repo listed in HF_REPOS, using the
    ``HF_TOKEN`` environment variable as the token if set. A
    ``<filename>.READY`` marker file is written next to each file once it is in
    place.

    Raises:
        RuntimeError: if no repo lists a file, or it is still missing after download.
    """
    for filename, key in REQUIRED_FILES:
        target_dir = PATHS[key]
        ensure_dir(target_dir)
        target_path = os.path.join(target_dir, filename)
        ready_flag = os.path.join(target_dir, f"{filename}.READY")

        if os.path.exists(target_path) and os.path.getsize(target_path) > 0:
            open(ready_flag, "a").close()
            continue

        repo_id, repo_file_path = _repo_source_for(filename)
        print(f"Downloading {filename} from {repo_id} to {target_dir} ...")
        snapshot_download(
            repo_id=repo_id,
            local_dir=target_dir,
            local_dir_use_symlinks=False,
            allow_patterns=[repo_file_path],
            token=os.getenv("HF_TOKEN", None),
        )

        if not os.path.exists(target_path):
            # A file stored in a repo subfolder (e.g. "models_comfy/...") lands
            # in the same subfolder under local_dir. Move it to the expected
            # location instead of copying, so two multi-GB copies are not kept.
            for root, _, files in os.walk(target_dir):
                if filename in files:
                    src = os.path.join(root, filename)
                    if src != target_path:
                        shutil.move(src, target_path)
                    break

        if not os.path.exists(target_path):
            raise RuntimeError(f"Failed to download required file: {filename}")
        open(ready_flag, "w").close()
        print(f"Downloaded and ready: {target_path}")


ensure_models()
def ensure_t5_tokenizer() -> None:
    """Guarantee a complete T5 tokenizer folder for the WanVideo wrapper nodes.

    The folder is ``custom_nodes/ComfyUI_WanVideoWrapper/configs/T5_tokenizer``
    next to this script. If any expected tokenizer file is absent or empty, the
    tokenizer is loaded from ``google/umt5-xxl`` via ``transformers.AutoTokenizer``
    and written there with ``save_pretrained``, after which the files are checked
    again. Any failure is printed with its traceback and then re-raised.
    """
    expected_files = (
        "tokenizer.json",
        "tokenizer_config.json",
        "spiece.model",
        "special_tokens_map.json",
    )
    try:
        here = os.path.dirname(os.path.abspath(__file__))
        tokenizer_dir = os.path.join(
            here, "custom_nodes", "ComfyUI_WanVideoWrapper", "configs", "T5_tokenizer"
        )
        os.makedirs(tokenizer_dir, exist_ok=True)

        def files_ready() -> bool:
            # Every expected file must exist and be non-empty.
            for name in expected_files:
                candidate = os.path.join(tokenizer_dir, name)
                if not (os.path.exists(candidate) and os.path.getsize(candidate) > 0):
                    return False
            return True

        if files_ready():
            print(f"T5 tokenizer ready at: {tokenizer_dir}")
            return

        print(f"Preparing T5 tokenizer at: {tokenizer_dir} ...")
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "google/umt5-xxl",
            use_fast=True,
            trust_remote_code=False,
        )
        tokenizer.save_pretrained(tokenizer_dir)

        if not files_ready():
            raise RuntimeError("Tokenizer files not fully prepared after save_pretrained")
        print("T5 tokenizer prepared successfully.")
    except Exception as e:
        print(f"Failed to prepare T5 tokenizer: {e}\n{traceback.format_exc()}")
        raise


ensure_t5_tokenizer()
def setup_global_logging_filter():
    """Reconfigure root logging at INFO and drop memory-statistics chatter.

    ``logging.basicConfig(force=True)`` replaces any existing root handlers. A
    filter is then attached to the first root handler; it rejects any record
    whose formatted message contains one of the fragments below.
    """
    noisy_fragments = (
        "Allocated memory:",
        "Max allocated memory:",
        "Max reserved memory:",
        "memory=",
        "max_memory=",
        "max_reserved=",
        "Block swap memory summary",
        "Transformer blocks on",
        "Total memory used by",
        "Non-blocking memory transfer",
    )

    class MemoryLogFilter(logging.Filter):
        def filter(self, record):
            text = record.getMessage()
            for fragment in noisy_fragments:
                if fragment in text:
                    return False
            return True

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        force=True,
    )
    root_handler = logging.getLogger().handlers[0]
    root_handler.addFilter(MemoryLogFilter())


setup_global_logging_filter()
def tensor_to_video(video_tensor, output_path, fps=20, crf=20):
    """Encode RGB frames to an H.264 (libx264, yuv420p) MP4 using the ffmpeg CLI.

    Args:
        video_tensor: torch tensor shaped (num_frames, height, width, 3), RGB.
            Non-uint8 input whose maximum is <= 1.0 is treated as [0, 1] and
            scaled to [0, 255]. Any other non-uint8 input is treated as already
            in [0, 255]. Values are rounded and clipped before the uint8 cast.
        output_path: destination file (ffmpeg runs with ``-y``, so it is overwritten).
        fps: frame rate passed to ffmpeg for both input and output.
        crf: value passed to libx264 as ``-crf``.

    Raises:
        ValueError: if the tensor is not (N, H, W, 3) or contains no pixels.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    tensor = video_tensor.detach().cpu()
    if tensor.is_floating_point():
        # NumPy has no bfloat16; upcast float tensors before .numpy().
        tensor = tensor.float()
    frames = tensor.numpy()

    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(f"expected frames shaped (N, H, W, 3), got {tuple(frames.shape)}")
    num_frames, height, width, _ = frames.shape
    if frames.size == 0:
        raise ValueError("video_tensor contains no frames/pixels")

    if frames.dtype != np.uint8:
        scale = 255.0 if frames.max() <= 1.0 else 1.0
        # Round and clip so slight over/undershoot does not wrap around in uint8.
        frames = np.clip(np.rint(frames * scale), 0, 255).astype(np.uint8)

    command = [
        'ffmpeg',
        '-hide_banner',
        '-loglevel', 'error',
        '-y',
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-pix_fmt', 'rgb24',
        '-s', f'{width}x{height}',
        '-r', str(fps),
        '-i', '-',
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        '-crf', str(crf),
        '-preset', 'medium',
        '-r', str(fps),
        '-an',
        output_path,
    ]
    # subprocess.run feeds stdin and drains stderr concurrently (communicate),
    # so a chatty ffmpeg cannot fill the stderr pipe and deadlock the writer.
    result = subprocess.run(
        command,
        input=frames.tobytes(),
        stderr=subprocess.PIPE,
        check=False,
    )
    if result.returncode != 0:
        details = result.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            f"ffmpeg exited with code {result.returncode} while writing "
            f"{output_path}: {details[-2000:]}"
        )
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]`` on KeyError.

    Node outputs here may be a plain sequence, or a mapping that presumably holds
    the outputs under ``"result"``. Only a ``KeyError`` from direct indexing
    triggers the fallback.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search *path* and each of its ancestors for an entry called *name*.

    Args:
        name: exact directory-entry name to look for.
        path: starting directory (defaults to the current working directory).
            It is made absolute first, so relative paths walk up correctly.

    Returns:
        The joined path of the first match (nearest ancestor wins), or None once
        the filesystem root has been checked without a match. Directories that
        cannot be listed are skipped rather than aborting the search.
    """
    current = os.path.abspath(os.getcwd() if path is None else path)
    while True:
        try:
            entries = os.listdir(current)
        except OSError:
            # Unreadable or missing directory: keep walking upward.
            entries = ()
        if name in entries:
            path_name = os.path.join(current, name)
            print(f"{name} found: {path_name}")
            return path_name
        parent = os.path.dirname(current)
        if parent == current:
            return None
        current = parent
def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``ComfyUI`` directory (cwd or an ancestor) to ``sys.path`` once."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is None or not os.path.isdir(comfyui_path):
        return
    if comfyui_path in sys.path:
        return
    sys.path.append(comfyui_path)
    print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
    """Load ``extra_model_paths.yaml`` (searched upward from cwd) through ComfyUI's loader.

    ``load_extra_path_config`` is taken from ``main``; if that import fails it
    falls back to ``utils.extra_config``. A missing config file is only reported.
    NOTE(review): this file also does ``from utils import get_path_after_pexel``;
    if that ``utils`` is a plain module, the ``utils.extra_config`` fallback
    cannot resolve — confirm which ``utils`` is on the path.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print(
            "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
        )
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if config_file is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(config_file)
# Put ComfyUI on sys.path first; presumably add_extra_model_paths() needs it to
# import ComfyUI's `main` module — confirm if the app is run from elsewhere.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
    """Initialise ComfyUI's server, prompt queue and extra (custom) nodes exactly once.

    A fresh asyncio event loop is created and installed as the current loop, and
    a ``PromptServer`` and ``PromptQueue`` are built on it before
    ``init_extra_nodes()`` runs. Later calls return early via the
    ``_initialized`` function attribute.
    NOTE(review): in newer ComfyUI releases ``init_extra_nodes`` may be a
    coroutine function needing ``loop.run_until_complete`` — confirm against
    the bundled ComfyUI version.
    """
    # Imports kept in their original order; ComfyUI modules may have import-time side effects.
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    if getattr(import_custom_nodes, "_initialized", False):
        return

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    prompt_server = server.PromptServer(event_loop)
    execution.PromptQueue(prompt_server)
    init_extra_nodes()
    import_custom_nodes._initialized = True
# Imported before the custom nodes are registered. Presumably init_extra_nodes()
# fills this same dict object, so later lookups such as "VHS_LoadVideo" still
# succeed — TODO confirm.
from nodes import NODE_CLASS_MAPPINGS

print("Loading custom nodes and models...")
import_custom_nodes()
def _free_memory() -> None:
    """Run Python's garbage collector and, when CUDA is available, empty PyTorch's CUDA cache."""
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _run_ditto_graph(vpath, prompt, width, height, fps, frame_count, outdir):
    """Load the models, edit the clip at *vpath*, and write ``<outdir>/<stem>_edit.mp4``.

    All intermediates (models, embeddings, latents, decoded frames) are locals
    of this frame. Once it returns, nothing here references them any more, so
    the caller's cleanup can reclaim them, unless the node implementations keep
    their own references.
    Returns the path of the written video.
    """
    with torch.inference_mode():
        from custom_nodes.ComfyUI_WanVideoWrapper import nodes as wan_nodes

        vhs_loadvideo = NODE_CLASS_MAPPINGS["VHS_LoadVideo"]()

        # Set model and settings.
        wanvideovacemodelselect = wan_nodes.WanVideoVACEModelSelect()
        wanvideovacemodelselect_89 = wanvideovacemodelselect.getvacepath(
            vace_model="ditto_global_comfy.safetensors"
        )
        wanvideoslg = wan_nodes.WanVideoSLG()
        wanvideoslg_113 = wanvideoslg.process(
            blocks="2",
            start_percent=0.20000000000000004,
            end_percent=0.7000000000000002,
        )
        wanvideovaeloader = wan_nodes.WanVideoVAELoader()
        wanvideovaeloader_133 = wanvideovaeloader.loadmodel(
            model_name="wan/Wan2_1_VAE_bf16.safetensors", precision="bf16"
        )
        loadwanvideot5textencoder = wan_nodes.LoadWanVideoT5TextEncoder()
        loadwanvideot5textencoder_134 = loadwanvideot5textencoder.loadmodel(
            model_name="umt5-xxl-enc-bf16.safetensors",
            precision="bf16",
            load_device="offload_device",
            quantization="disabled",
        )
        wanvideoblockswap = wan_nodes.WanVideoBlockSwap()
        wanvideoblockswap_137 = wanvideoblockswap.setargs(
            blocks_to_swap=20,
            offload_img_emb=False,
            offload_txt_emb=False,
            use_non_blocking=True,
            vace_blocks_to_swap=0,
        )
        wanvideoloraselect = wan_nodes.WanVideoLoraSelect()
        wanvideoloraselect_380 = wanvideoloraselect.getlorapath(
            lora="Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
            strength=1.0,
            low_mem_load=False,
        )
        wanvideomodelloader = wan_nodes.WanVideoModelLoader()
        imageresizekjv2 = NODE_CLASS_MAPPINGS["ImageResizeKJv2"]()
        wanvideovaceencode = wan_nodes.WanVideoVACEEncode()
        wanvideotextencode = wan_nodes.WanVideoTextEncode()
        wanvideosampler = wan_nodes.WanVideoSampler()
        wanvideodecode = wan_nodes.WanVideoDecode()
        wanvideomodelloader_142 = wanvideomodelloader.loadmodel(
            model="Wan2_1-T2V-14B_fp8_e4m3fn.safetensors",
            base_precision="fp16",
            quantization="disabled",
            load_device="offload_device",
            attention_mode="sdpa",
            block_swap_args=get_value_at_index(wanvideoblockswap_137, 0),
            lora=get_value_at_index(wanvideoloraselect_380, 0),
            vace_model=get_value_at_index(wanvideovacemodelselect_89, 0),
        )

        fname_clean = os.path.splitext(os.path.basename(vpath))[0]
        # NOTE(review): frames are loaded at a fixed force_rate=20, while `fps`
        # only sets the output rate; a different `fps` changes playback speed
        # relative to the source — confirm this is intended.
        vhs_loadvideo_70 = vhs_loadvideo.load_video(
            video=vpath,
            force_rate=20,
            custom_width=width,
            custom_height=height,
            frame_load_cap=frame_count,
            skip_first_frames=1,
            select_every_nth=1,
            format="AnimateDiff",
            unique_id=16696422174153060213,
        )
        imageresizekjv2_205 = imageresizekjv2.resize(
            width=width,
            height=height,
            upscale_method="area",
            keep_proportion="resize",
            pad_color="0, 0, 0",
            crop_position="center",
            divisible_by=8,
            device="cpu",
            image=get_value_at_index(vhs_loadvideo_70, 0),
        )
        wanvideovaceencode_29 = wanvideovaceencode.process(
            width=width,
            height=height,
            num_frames=frame_count,
            strength=0.9750000000000002,
            vace_start_percent=0,
            vace_end_percent=1,
            tiled_vae=False,
            vae=get_value_at_index(wanvideovaeloader_133, 0),
            input_frames=get_value_at_index(imageresizekjv2_205, 0),
        )
        wanvideotextencode_148 = wanvideotextencode.process(
            positive_prompt=prompt,
            negative_prompt="flickering artifact, jpg artifacts, compression, distortion, morphing, low-res, fake, oversaturated, overexposed, over bright, strange behavior, distorted limbs, unnatural motion, unrealistic anatomy, glitch, extra limbs,",
            force_offload=True,
            t5=get_value_at_index(loadwanvideot5textencoder_134, 0),
            model_to_offload=get_value_at_index(wanvideomodelloader_142, 0),
        )

        # Clean memory before sampling (most memory-intensive step).
        _free_memory()
        wanvideosampler_2 = wanvideosampler.process(
            steps=4,
            cfg=1.2000000000000002,
            shift=2.0000000000000004,
            # randint's upper bound is inclusive; 2**64 itself would overflow a uint64 seed.
            seed=random.randint(1, 2 ** 64 - 1),
            force_offload=True,
            scheduler="unipc",
            riflex_freq_index=0,
            denoise_strength=1,
            batched_cfg=False,
            rope_function="comfy",
            model=get_value_at_index(wanvideomodelloader_142, 0),
            image_embeds=get_value_at_index(wanvideovaceencode_29, 0),
            text_embeds=get_value_at_index(wanvideotextencode_148, 0),
            slg_args=get_value_at_index(wanvideoslg_113, 0),
        )
        res = wanvideodecode.decode(
            enable_vae_tiling=False,
            tile_x=272,
            tile_y=272,
            tile_stride_x=144,
            tile_stride_y=128,
            vae=get_value_at_index(wanvideovaeloader_133, 0),
            samples=get_value_at_index(wanvideosampler_2, 0),
        )

        save_path = os.path.join(outdir, f'{fname_clean}_edit.mp4')
        tensor_to_video(res[0], save_path, fps=fps)
    return save_path


def run_pipeline(vpath, prompt, width, height, fps, frame_count, outdir):
    """Edit the video at *vpath* according to *prompt* and save the result as MP4.

    Args:
        vpath: input video file path.
        prompt: editing instruction (positive prompt).
        width, height: target resolution for loading, resizing and encoding.
        fps: frame rate written to the output file.
        frame_count: maximum number of frames loaded and encoded.
        outdir: output directory (created if missing).

    Returns:
        Path of the written ``<outdir>/<input stem>_edit.mp4``.

    Raises:
        Any exception from the node graph or the encoder is printed with its
        traceback and re-raised. Memory is released on both success and failure.
    """
    _free_memory()
    try:
        os.makedirs(outdir, exist_ok=True)
        save_path = _run_ditto_graph(vpath, prompt, width, height, fps, frame_count, outdir)
    except Exception as e:
        err = f"Error: {e}\n{traceback.format_exc()}"
        print(err)
        raise
    finally:
        _free_memory()
    print(f"Done. Saved to: {save_path}")
    return save_path
def gradio_infer(vfile, prompt, width, height, fps, frame_count, progress=gr.Progress(track_tqdm=True)):
    """Gradio click handler: resolve the input video path and run the edit pipeline.

    Args:
        vfile: value of the ``gr.Video`` input, either a path string or an
            object with ``.name`` (and possibly ``.save``).
        prompt: editing instruction.
        width, height, fps, frame_count: values of the ``gr.Number`` fields.
        progress: not read here. It is presumably declared so Gradio's
            ``track_tqdm`` can surface tqdm progress bars.

    Returns:
        Path of the edited video, for the single ``Result`` output.

    Raises:
        gr.Error: if no video is given or a numeric field is empty or non-positive,
            so the message is shown in the UI instead of crashing the handler.
    NOTE(review): this Space runs on ZeroGPU ("Running on Zero") and imports
    ``spaces``, but no ``@spaces.GPU`` decorator is visible on this handler —
    confirm the GPU allocation is set up elsewhere.
    """
    if vfile is None:
        raise gr.Error("Please upload the video!")

    numeric_fields = {"Width": width, "Height": height, "FPS": fps, "Frame Count": frame_count}
    empty = [label for label, value in numeric_fields.items() if value is None]
    if empty:
        raise gr.Error(f"Please fill in: {', '.join(empty)}")
    if any(int(value) <= 0 for value in numeric_fields.values()):
        raise gr.Error("Width, Height, FPS and Frame Count must be positive.")

    vpath = vfile if isinstance(vfile, str) else vfile.name
    if not os.path.exists(vpath) and hasattr(vfile, "save"):
        os.makedirs("uploads", exist_ok=True)
        vpath = os.path.join("uploads", os.path.basename(vfile.name))
        vfile.save(vpath)

    outdir = "results"
    os.makedirs(outdir, exist_ok=True)
    return run_pipeline(
        vpath=vpath,
        prompt=prompt,
        width=int(width),
        height=int(height),
        fps=int(fps),
        frame_count=int(frame_count),
        outdir=outdir,
    )
def build_interface():
    """Build the Ditto demo UI and return the (not yet launched) ``gr.Blocks`` app.

    Layout: a Markdown header (title, links, usage notes), input and result
    video players, an instruction textbox, width/height/FPS/frame-count number
    fields, a Run button wired to ``gradio_infer``, and three examples. The
    examples only fill in the inputs, since ``gr.Examples`` gets no ``fn``.
    """
    with gr.Blocks(title="Ditto") as demo:
        # Header: paper/project/code/weights/dataset links plus usage notes.
        gr.Markdown(
            """# Ditto: Scaling Instruction-Based Video Editing with a High-Quality Synthetic Dataset
<div style="font-size: 1.8rem; line-height: 1.6; margin-bottom: 1rem;">
<a href="https://arxiv.org/abs/2510.15742" target="_blank">📄 Paper</a>
|
<a href="https://ezioby.github.io/Ditto_page/" target="_blank">🌐 Project Page</a>
|
<a href="https://github.com/EzioBy/Ditto/" target="_blank"> 💻 Github Code </a>
|
<a href="https://huggingface.co/QingyanBai/Ditto_models/tree/main" target="_blank">📦 Model Weights</a>
|
<a href="https://huggingface.co/datasets/QingyanBai/Ditto-1M" target="_blank">📊 Dataset</a>
</div>
<b>Note1:</b> The backend of this demo is comfy. Though it runs fast, please note that due to the use of quantized and distilled models, there may be some quality degradation.
<b>Note2:</b> Considering the limited memory, please try test cases with lower resolution and frame count, otherwise it may cause out of memory error (you can also try re-running it).
If you like this project, please consider <a href="https://github.com/EzioBy/Ditto/" target="_blank">starring the repo</a> to motivate us. Thank you!
"""
        )
        with gr.Column():
            with gr.Row():
                # Input defaults to the bundled sample clip input/dasha.mp4.
                vfile = gr.Video(label="Input Video", value=os.path.join("input", "dasha.mp4"),
                                 sources="upload", interactive=True)
                out_video = gr.Video(label="Result")
            prompt = gr.Textbox(label="Editing Instruction", value="Make it in the style of Japanese anime")
            with gr.Row():
                # precision=0 -> integer-valued fields; defaults match the examples below.
                width = gr.Number(label="Width", value=576, precision=0)
                height = gr.Number(label="Height", value=324, precision=0)
                fps = gr.Number(label="FPS", value=20, precision=0)
                frame_count = gr.Number(label="Frame Count", value=49, precision=0)
            run_btn = gr.Button("Run", variant="primary")
        # gradio_infer returns a single value: the edited video's path.
        run_btn.click(
            fn=gradio_infer,
            inputs=[vfile, prompt, width, height, fps, frame_count],
            outputs=[out_video]
        )
        # Each row: [video, instruction, width, height, fps, frame_count].
        examples = [
            [
                os.path.join("input", "dasha.mp4"),
                "Add some fire and flame to the background",
                576, 324, 20, 49
            ],
            [
                os.path.join("input", "dasha.mp4"),
                "Add some snow and flakes to the background",
                576, 324, 20, 49
            ],
            [
                os.path.join("input", "dasha.mp4"),
                "Make it in the style of pencil sketch",
                576, 324, 20, 49
            ],
        ]
        gr.Examples(
            examples=examples,
            inputs=[vfile, prompt, width, height, fps, frame_count],
            label="Examples"
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, keep it bound to the module-level name
    # `demo`, and start the Gradio server.
    demo = build_interface()
    demo.launch()