File size: 6,915 Bytes
359fa44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    """Return the ``(height, width)`` of a rank-3 or rank-4 image tensor.

    Rank 4 yields ``(shape[1], shape[2])`` and rank 3 yields
    ``(shape[0], shape[1])``; both are the two axes just before the last one.
    This presumes a channels-last layout such as ``(B, H, W, C)`` /
    ``(H, W, C)`` — NOTE(review): confirm against callers.

    Raises:
        ValueError: if the tensor has any rank other than 3 or 4.
    """
    rank = len(image.shape)
    if rank not in (3, 4):
        raise ValueError("Invalid image tensor shape.")
    # Counting from the end makes the batched and unbatched cases identical.
    return image.shape[-3], image.shape[-2]
def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Check an image's pixel size against optional per-axis bounds.

    Any bound left as ``None`` is not enforced. Bounds are checked in the order
    min width, max width, min height, max height, and the first violated one
    raises.

    Raises:
        ValueError: describing the first bound the image violates.
    """
    img_h, img_w = get_image_dimensions(image)

    if min_width is not None:
        if img_w < min_width:
            raise ValueError(f"Image width must be at least {min_width}px, got {img_w}px")
    if max_width is not None:
        if img_w > max_width:
            raise ValueError(f"Image width must be at most {max_width}px, got {img_w}px")
    if min_height is not None:
        if img_h < min_height:
            raise ValueError(f"Image height must be at least {min_height}px, got {img_h}px")
    if max_height is not None:
        if img_h > max_height:
            raise ValueError(f"Image height must be at most {max_height}px, got {img_h}px")
def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    """Validate that an image's width/height ratio lies within optional bounds.

    Args:
        image: Image tensor accepted by ``get_image_dimensions``.
        min_aspect_ratio: Smallest allowed width / height, or ``None`` to skip.
        max_aspect_ratio: Largest allowed width / height, or ``None`` to skip.

    Raises:
        ValueError: if the image has a non-positive dimension, or its aspect
            ratio falls outside the given bounds.
    """
    # get_image_dimensions returns (height, width) -- the same order
    # validate_image_dimensions unpacks. Unpacking it as (width, height)
    # would silently validate the inverse ratio.
    height, width = get_image_dimensions(image)
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid image dimensions: {width}x{height}")
    aspect_ratio = width / height
    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
    image: torch.Tensor,
    min_ratio: tuple[float, float],  # e.g. (1, 4)
    max_ratio: tuple[float, float],  # e.g. (4, 1)
    *,
    strict: bool = True,  # True -> (min, max); False -> [min, max]
) -> float:
    """Validate that an image's width/height ratio lies between two W:H ratios.

    Args:
        image: Image tensor accepted by ``get_image_dimensions``.
        min_ratio: Lower bound as a ``(width, height)`` pair, e.g. ``(1, 4)``.
        max_ratio: Upper bound as a ``(width, height)`` pair, e.g. ``(4, 1)``.
            If the two bounds are given in the wrong order they are swapped.
        strict: If true the bounds are exclusive, otherwise inclusive.

    Returns:
        The image's width / height ratio.

    Raises:
        ValueError: if any ratio component is non-positive, the image has a
            non-positive dimension, or its ratio is outside the range.
    """
    a1, b1 = min_ratio
    a2, b2 = max_ratio
    if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
        raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
    lo, hi = (a1 / b1), (a2 / b2)
    if lo > hi:
        lo, hi = hi, lo
        a1, b1, a2, b2 = a2, b2, a1, b1  # swap only for error text
    # get_image_dimensions returns (height, width); unpacking it the other way
    # round would test the inverse ratio against the bounds.
    h, w = get_image_dimensions(image)
    if w <= 0 or h <= 0:
        raise ValueError(f"Invalid image dimensions: {w}x{h}")
    ar = w / h
    ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
    if not ok:
        op = "<" if strict else "≤"
        raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
    return ar
def validate_aspect_ratio_closeness(
    start_img,
    end_img,
    min_rel: float,
    max_rel: float,
    *,
    strict: bool = False,  # True => exclusive, False => inclusive
) -> None:
    """Validate that two images have approximately the same aspect ratio.

    The comparison is symmetric: the larger of the two width/height ratios is
    divided by the smaller. The result must not exceed
    ``max(max_rel, 1 / min_rel)`` (1.25 for ``min_rel=0.8, max_rel=1.25``).
    With ``strict=True`` reaching the limit also fails.

    Args:
        start_img: First image tensor (accepted by ``get_image_dimensions``).
        end_img: Second image tensor.
        min_rel: Smallest allowed start/end ratio; must be positive.
        max_rel: Largest allowed start/end ratio; must be positive.
        strict: Whether the limit itself is excluded.

    Raises:
        ValueError: if either image has a non-positive dimension, a relative
            bound is non-positive, or the aspect ratios are too far apart.
    """
    # get_image_dimensions returns (height, width).
    h1, w1 = get_image_dimensions(start_img)
    h2, w2 = get_image_dimensions(end_img)
    if min(w1, h1, w2, h2) <= 0:
        raise ValueError("Invalid image dimensions")
    ar1 = w1 / h1
    ar2 = w2 / h2
    # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
    closeness = max(ar1, ar2) / min(ar1, ar2)
    if min_rel <= 0 or max_rel <= 0:
        # Guards the 1.0 / min_rel below against division by zero.
        raise ValueError(f"Relative bounds must be positive, got {min_rel}–{max_rel}.")
    limit = max(max_rel, 1.0 / min_rel)  # for 0.8..1.25 this is 1.25
    if (closeness >= limit) if strict else (closeness > limit):
        raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
def validate_video_dimensions(
    video: Input.Video,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Check a video's frame size against optional per-axis bounds.

    ``video.get_dimensions()`` is unpacked as ``(width, height)``. If that call
    (or the unpacking) raises, the error is logged and validation is skipped
    rather than failing. A bound left as ``None`` is not enforced.

    Raises:
        ValueError: describing the first bound the video violates, checked in
            the order min width, max width, min height, max height.
    """
    try:
        vid_w, vid_h = video.get_dimensions()
    except Exception as exc:
        # Best-effort: an unreadable size is logged, not treated as invalid.
        logging.error("Error getting dimensions of video: %s", exc)
        return

    if min_width is not None:
        if vid_w < min_width:
            raise ValueError(f"Video width must be at least {min_width}px, got {vid_w}px")
    if max_width is not None:
        if vid_w > max_width:
            raise ValueError(f"Video width must be at most {max_width}px, got {vid_w}px")
    if min_height is not None:
        if vid_h < min_height:
            raise ValueError(f"Video height must be at least {min_height}px, got {vid_h}px")
    if max_height is not None:
        if vid_h > max_height:
            raise ValueError(f"Video height must be at most {max_height}px, got {vid_h}px")
def validate_video_duration(
    video: Input.Video,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    """Check a video's duration (seconds) against optional bounds.

    If ``video.get_duration()`` raises, the error is logged and validation is
    skipped. Both bounds are widened by a fixed tolerance of 0.0001 s.

    Raises:
        ValueError: if the duration is shorter than ``min_duration`` or longer
            than ``max_duration`` beyond that tolerance.
    """
    try:
        duration = video.get_duration()
    except Exception as exc:
        # Best-effort: an unreadable duration is logged, not treated as invalid.
        logging.error("Error getting duration of video: %s", exc)
        return

    # Small slack, presumably to absorb floating-point error in the reported
    # duration.
    tolerance = 0.0001
    too_short = min_duration is not None and duration < min_duration - tolerance
    if too_short:
        raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
    too_long = max_duration is not None and duration > max_duration + tolerance
    if too_long:
        raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def get_number_of_images(images):
    """Return how many images ``images`` holds.

    For a tensor of rank 4 or higher this is the length of the first axis;
    a lower-rank tensor counts as a single image. Any non-tensor is measured
    with ``len()``.
    """
    if not isinstance(images, torch.Tensor):
        return len(images)
    if images.ndim < 4:
        return 1
    return images.shape[0]
def validate_audio_duration(
    audio: Input.Audio,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
) -> None:
    """Validate that an audio clip's duration lies within optional bounds.

    The duration is the length of the last axis of ``audio["waveform"]``
    divided by ``audio["sample_rate"]``. Both bounds are widened by one sample
    period (``1 / sample_rate`` seconds), so a clip off by a single sample is
    not rejected.

    Args:
        audio: Mapping with ``"waveform"`` (array-like with ``.shape``) and
            ``"sample_rate"`` entries.
        min_duration: Minimum duration in seconds, or ``None`` to skip.
        max_duration: Maximum duration in seconds, or ``None`` to skip.

    Raises:
        ValueError: if the sample rate is not positive, or the duration falls
            outside the allowed range.
    """
    sr = int(audio["sample_rate"])
    if sr <= 0:
        # Avoids ZeroDivisionError / negative durations below.
        raise ValueError(f"Audio sample rate must be positive, got {sr}")
    dur = int(audio["waveform"].shape[-1]) / sr
    eps = 1.0 / sr
    # Report the measured duration, not the tolerance-adjusted value.
    if min_duration is not None and dur + eps < min_duration:
        raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur:.2f}s")
    if max_duration is not None and dur - eps > max_duration:
        raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur:.2f}s")
def validate_string(
    string: str,
    strip_whitespace=True,
    field_name="prompt",
    min_length=None,
    max_length=None,
):
    """Validate that a user-supplied string field is present and sized correctly.

    Args:
        string: The value to check; ``None`` is rejected outright.
        strip_whitespace: If true, leading/trailing whitespace is removed
            before the length is measured.
        field_name: Name used in error messages.
        min_length: Minimum allowed length; a falsy value (``None`` or ``0``)
            disables the check.
        max_length: Maximum allowed length; a falsy value disables the check.

    Raises:
        Exception: if ``string`` is ``None`` or its (optionally stripped)
            length is outside the allowed range.
    """
    if string is None:
        raise Exception(f"Field '{field_name}' cannot be empty.")
    if strip_whitespace:
        string = string.strip()
    length = len(string)
    if min_length and length < min_length:
        raise Exception(
            f"Field '{field_name}' cannot be shorter than {min_length} characters; was {length} characters long."
        )
    if max_length and length > max_length:
        # Same message shape as the min-length error: no leading space, and the
        # field name is properly quoted.
        raise Exception(
            f"Field '{field_name}' cannot be longer than {max_length} characters; was {length} characters long."
        )
def validate_container_format_is_mp4(video: VideoInput) -> None:
    """Raise ``ValueError`` unless the video's container format is MP4.

    Two format names are accepted: the bare ``"mp4"`` and the combined
    ``"mov,mp4,m4a,3gp,3g2,mj2"``. The combined name presumably comes from the
    underlying demuxer (it matches FFmpeg's mov/mp4 demuxer name); confirm.
    """
    fmt = video.get_container_format()
    accepted_formats = ("mp4", "mov,mp4,m4a,3gp,3g2,mj2")
    if fmt in accepted_formats:
        return
    raise ValueError(f"Only MP4 container format supported. Got: {fmt}")
|