Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| from typing import Callable, Optional, List | |
def ordered_halving(val):
    """Map an integer into [0, 1) by mirroring its 64-bit binary digits.

    Reversing the bit string spreads consecutive integers over the unit
    interval: 0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, ...

    A negative ``val`` raises ValueError, because the minus sign ends up
    inside the reversed digit string. For ``val >= 2**64`` the digit string
    is longer than 64 characters, so the result may be >= 1.
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / 2**64
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Detect whether a window wraps past the end of the frame range.

    Every value is reduced modulo ``num_frames``. The window "rolls over" at
    the first position whose wrapped value is smaller than the wrapped value
    just before it.

    Returns:
        ``(True, index)`` for the first such position, otherwise ``(False, -1)``.
    """
    wrapped = [frame % num_frames for frame in window]
    # Pair each wrapped value with its predecessor; -1 stands in before the first one.
    predecessors = [-1] + wrapped
    for position, (before, current) in enumerate(zip(predecessors, wrapped)):
        if current < before:
            return True, position
    return False, -1
def shift_window_to_start(window: list[int], num_frames: int):
    """Re-base ``window`` in place so that its first element becomes 0.

    Every value is offset by ``-window[0]`` and wrapped into
    ``[0, num_frames)``. The list is modified in place and None is returned.
    Raises IndexError for an empty window.
    """
    origin = window[0]
    # Adding num_frames before the modulus keeps the intermediate value non-negative.
    window[:] = [((frame - origin) + num_frames) % num_frames for frame in window]
def shift_window_to_end(window: list[int], num_frames: int):
    """Slide ``window`` in place so that its last element equals ``num_frames - 1``.

    The window is first re-based to start at 0 with ``shift_window_to_start``.
    Every value is then offset by the same amount, so the relative spacing is
    preserved. The list is modified in place and None is returned.
    """
    shift_window_to_start(window, num_frames)
    offset = num_frames - window[-1] - 1
    window[:] = [frame + offset for frame in window]
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return the frame indexes in ``[0, num_frames)`` that no window covers.

    The result is in ascending order. Window values outside that range are
    ignored, and duplicates across windows are harmless.
    """
    # Build the covered set once so each membership test is O(1). The previous
    # per-value list.remove() inside try/except-pass was O(frames * values).
    covered = {frame for window in windows for frame in window}
    return [idx for idx in range(num_frames) if idx not in covered]
def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield context windows of frame indexes that wrap around modulo ``num_frames``.

    Windows are produced for several spacing levels: ``1, 2, 4, ...``
    (``1 << np.arange(context_stride)``). At each level, a window holds
    ``context_size`` indexes spaced that far apart. Consecutive windows
    start ``context_size * spacing - context_overlap`` frames apart.
    The starting offset depends on ``ordered_halving(step)``, so the
    windows move from one step to the next.

    If ``closed_loop`` is False, window starts stop ``context_overlap`` frames
    earlier. ``num_steps`` is accepted but not used.

    When ``num_frames <= context_size``, a single window covering every frame
    is yielded. ``context_size`` must be provided: the comparison fails on
    the default ``None``.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    # Cap the number of spacing levels so the widest level (spacing
    # 2**(levels - 1)) just spans all frames.
    level_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, level_cap)

    for spacing in 1 << np.arange(context_stride):
        phase = ordered_halving(step)
        pad = int(round(num_frames * phase))
        first_start = int(phase * spacing) + pad
        start_limit = num_frames + pad + (0 if closed_loop else -context_overlap)
        hop = context_size * spacing - context_overlap
        for origin in range(first_start, start_limit, hop):
            yield [
                frame % num_frames
                for frame in range(origin, origin + context_size * spacing, spacing)
            ]
# from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return unique context windows that never wrap past the last frame.

    The windows are first generated exactly as in ``uniform_looped``. Any
    window that wraps is then slid so that it ends at frame
    ``num_frames - 1``. If the next window (cyclically) does not contain the
    value at which the wrap happened, a contiguous window starting at that
    value is inserted right after it. Finally, duplicate windows are removed,
    keeping the first occurrence.

    When ``num_frames <= context_size``, a single window covering every frame
    is returned. ``num_steps`` is accepted but not used.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    windows = []
    level_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, level_cap)
    for spacing in 1 << np.arange(context_stride):
        phase = ordered_halving(step)
        pad = int(round(num_frames * phase))
        first_start = int(phase * spacing) + pad
        start_limit = num_frames + pad + (0 if closed_loop else -context_overlap)
        hop = context_size * spacing - context_overlap
        for origin in range(first_start, start_limit, hop):
            windows.append(
                [frame % num_frames for frame in range(origin, origin + context_size * spacing, spacing)]
            )

    # Pass 1: un-wrap rolled windows in place, sequentially. This pass may
    # insert a window directly after the current one; that window is visited
    # on the next iteration.
    position = 0
    while position < len(windows):
        current = windows[position]
        rolled, roll_at = does_window_roll_over(current, num_frames)
        if rolled:
            # For wider spacings the wrap point need not be frame 0.
            roll_val = current[roll_at]
            shift_window_to_end(current, num_frames=num_frames)
            following = windows[(position + 1) % len(windows)]
            if roll_val not in following:
                windows.insert(position + 1, list(range(roll_val, roll_val + context_size)))
        position += 1

    # Pass 2: drop repeated windows, keeping the first occurrence and the original order.
    unique = []
    for window in windows:
        if window not in unique:
            unique.append(window)
    return unique
def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return a fixed set of contiguous, non-wrapping context windows.

    The result does not depend on ``step``. ``step``, ``num_steps``,
    ``context_stride`` and ``closed_loop`` are accepted so that this function
    can be called with the same arguments as the other schedulers, but they
    are not used.

    Windows start every ``context_size - context_overlap`` frames. The final
    window is pinned to end at ``num_frames - 1`` so it still holds
    ``context_size`` frames.

    Raises:
        ValueError: if ``num_frames > context_size`` and
            ``context_overlap >= context_size``. Such a configuration would
            never advance the window start; previously it caused an opaque
            range() error or silently returned no windows at all.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    delta = context_size - context_overlap
    if delta <= 0:
        raise ValueError(
            f"context_overlap ({context_overlap}) must be smaller than "
            f"context_size ({context_size})"
        )

    windows = []
    for start_idx in range(0, num_frames, delta):
        if start_idx + context_size >= num_frames:
            # Pull the last window back so it ends exactly at the final frame.
            windows.append(list(range(num_frames - context_size, num_frames)))
            break
        windows.append(list(range(start_idx, start_idx + context_size)))
    return windows
def get_context_scheduler(name: str) -> Callable:
    """Look up a context scheduler function by name.

    Recognised names are ``"uniform_looped"``, ``"uniform_standard"`` and
    ``"static_standard"``.

    Raises:
        ValueError: for any other name.
    """
    known_names = ("uniform_looped", "uniform_standard", "static_standard")
    if name not in known_names:
        raise ValueError(f"Unknown context_overlap policy {name}")
    if name == "uniform_looped":
        return uniform_looped
    if name == "uniform_standard":
        return uniform_standard
    return static_standard
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces across all timesteps.

    ``scheduler`` is called once for each timestep index ``i`` in
    ``range(len(timesteps))``, using the positional argument order shared by
    the schedulers in this module:
    ``(step, num_steps, num_frames, context_size, context_stride,
    context_overlap, closed_loop)``.

    Returns:
        The total number of windows yielded or returned across all calls.
    """
    # closed_loop is forwarded too. Previously it was accepted but dropped,
    # so schedulers always ran with their default (True).
    return sum(
        sum(
            1
            for _ in scheduler(
                i,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                closed_loop,
            )
        )
        for i in range(len(timesteps))
    )