"""Helpers for environment checks, CUDA memory reporting, LoRA weight
application, tiled CLIP-vision encoding and Fourier-domain frequency scaling."""
import importlib.metadata
import logging
import re

import torch
import torch.fft as fft
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

# These stay after logging.basicConfig to preserve the original load order.
from accelerate.utils import set_module_tensor_to_device  # noqa: E402
from comfy.clip_vision import clip_preprocess, ClipVisionModel  # noqa: E402


def _parse_version(version):
    """Return the leading numeric release components of ``version`` as ints.

    Parsing stops at the first dot-separated component that does not start
    with a digit, e.g. "0.31.0.dev0" -> (0, 31, 0), "0.32.0rc1" -> (0, 32, 0).
    """
    numbers = []
    for component in version.split("."):
        digits = re.match(r"\d+", component)
        if digits is None:
            break
        numbers.append(int(digits.group()))
    return tuple(numbers)


def _version_less_than(installed, required):
    """Numerically compare release segments; shorter versions are zero-padded."""
    have, need = _parse_version(installed), _parse_version(required)
    width = max(len(have), len(need))
    return have + (0,) * (width - len(have)) < need + (0,) * (width - len(need))


def check_diffusers_version():
    """Ensure diffusers >= 0.31.0 is installed.

    Raises:
        AssertionError: if diffusers is missing or older than required.
    """
    required_version = '0.31.0'
    try:
        version = importlib.metadata.version('diffusers')
    except importlib.metadata.PackageNotFoundError:
        raise AssertionError("diffusers is not installed.")
    # Compare numerically: plain string comparison would rank "0.4.0" above "0.31.0".
    if _version_less_than(version, required_version):
        raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")


def print_memory(device):
    """Log current/peak allocated and peak reserved CUDA memory (GB) for ``device``."""
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    log.info(f"Allocated memory: {memory=:.3f} GB")
    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
    #log.info(f"Memory Summary:\n{memory_summary}")


def get_module_memory_mb(module):
    """Return the total size of ``module``'s parameters in MiB."""
    total_bytes = sum(
        param.nelement() * param.element_size()
        for param in module.parameters()
        if param.data is not None
    )
    return total_bytes / (1024 * 1024)  # Convert to MB


def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, dtype=None, base_dtype=None, state_dict=None, low_mem_load=False):
    """Patch weights of every leaf module of ``model.model`` via ``model.patch_weight_to_device``.

    Args:
        model: Object exposing ``model.model`` (a ``torch.nn.Module``),
            ``patch_weight_to_device`` and ``patches_uuid`` (presumably a
            ComfyUI ModelPatcher -- confirm against callers).
        device_to: Device passed to ``model.patch_weight_to_device``.
        transformer_load_device: Device that ``diffusion_model`` tensors are
            materialised on when ``low_mem_load`` is set.
        params_to_keep: Substrings; names containing any of them use
            ``base_dtype`` instead of ``dtype``. ``None`` means no substrings.
        dtype: Default dtype used for low-memory loading.
        base_dtype: dtype for names matched by ``params_to_keep``.
        state_dict: Source tensors keyed by name without the
            ``diffusion_model.`` prefix; only read when ``low_mem_load`` is set.
        low_mem_load: Materialise each weight from ``state_dict`` just before
            patching it, then sweep any parameters still on another device.

    Returns:
        The same ``model``. Processed modules get ``comfy_patched_weights = True``
        and ``model.current_weight_patches_uuid`` is set to ``model.patches_uuid``.
    """
    keep_keywords = params_to_keep if params_to_keep is not None else ()
    prefix = "diffusion_model."

    def _dtype_for(tensor_name):
        # patch embeddings are always float32; params_to_keep matches keep base_dtype.
        if "patch_embedding" in tensor_name:
            return torch.float32
        if any(keyword in tensor_name for keyword in keep_keywords):
            return base_dtype
        return dtype

    to_load = []
    for module_name, module in model.model.named_modules():
        own_params = [param_name for param_name, _ in module.named_parameters(recurse=False)]
        own_set = set(own_params)
        # skip random weights in non leaf modules: any parameter reachable only
        # through a child is handled when that child itself is visited.
        if any(param_name not in own_set for param_name, _ in module.named_parameters(recurse=True)):
            continue
        if hasattr(module, "comfy_cast_weights") or own_params:
            to_load.append((module_name, module, own_params))
    to_load.sort(key=lambda entry: entry[0], reverse=True)

    for module_name, module, params in tqdm(to_load, desc="Loading model and applying LoRA weights:", leave=True):
        if getattr(module, "comfy_patched_weights", False):
            continue
        module_name = module_name.replace("._orig_mod.", ".")  # torch compiled modules have this prefix
        for param in params:
            # Only set for names under diffusion_model.; guards the post-patch reload
            # so it never touches a parameter from a previous module.
            key = None
            if low_mem_load:
                dtype_to_use = _dtype_for(module_name)
                if module_name.startswith(prefix):
                    key = "{}.{}".format(module_name[len(prefix):], param)
                    try:
                        set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[key])
                    except Exception:
                        # Best effort: an unloadable tensor leaves this weight unpatched.
                        continue
            model.patch_weight_to_device("{}.{}".format(module_name, param), device_to=device_to)
            if key is not None:
                try:
                    set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=model.model.diffusion_model.state_dict()[key])
                except Exception:
                    continue
        module.comfy_patched_weights = True

    model.current_weight_patches_uuid = model.patches_uuid
    if low_mem_load:
        # Sweep parameters that were not materialised above (e.g. non-leaf weights).
        for param_name, param in model.model.diffusion_model.named_parameters():
            if param.device != transformer_load_device:
                try:
                    set_module_tensor_to_device(model.model.diffusion_model, param_name, device=transformer_load_device, dtype=_dtype_for(param_name), value=state_dict[param_name])
                except Exception:
                    continue
    return model


# from https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/9d076a3df0d2763cef5510ec5ab807f6632c39f5/utils.py#L181
def split_tiles(embeds, num_split):
    """Cut each image into a ``num_split`` x ``num_split`` grid of tiles.

    ``embeds`` is indexed as (batch, H, W, C); trailing rows/columns that do not
    fill a whole tile are dropped. Returns a tensor of shape
    (batch, num_split**2, H // num_split, W // num_split, C), tiles in row-major order.
    """
    _, H, W, _ = embeds.shape
    h, w = H // num_split, W // num_split
    out = []
    for x in embeds:
        x = x.unsqueeze(0)
        x_split = torch.cat([x[:, i*h:(i+1)*h, j*w:(j+1)*w, :] for i in range(num_split) for j in range(num_split)], dim=0)
        out.append(x_split)
    return torch.stack(out, dim=0)


def merge_hiddenstates(x, tiles):
    """Merge per-tile hidden states back into one sequence per image.

    ``x`` is split along dim 0 into chunks of ``tiles * tiles`` rows, one chunk
    per image. Each row is assumed to be (1 + t*t, D): a class token followed
    by a square t x t grid of patch tokens (TODO confirm for non-square encoders).
    The patch grids are stitched spatially, average-pooled back to t x t, and
    prefixed with the mean class token. Returns shape (num_images, 1 + t*t, D).
    """
    chunk_size = tiles*tiles
    x = x.split(chunk_size)

    out = []
    for embeds in x:
        num_tiles = embeds.shape[0]
        tile_size = int((embeds.shape[1]-1) ** 0.5)
        grid_size = int(num_tiles ** 0.5)

        # Class tokens of every tile: [num_tiles, D]
        class_tokens = embeds[:, 0, :]
        avg_class_token = class_tokens.mean(dim=0, keepdim=True).unsqueeze(0)  # [1, 1, D]

        patch_embeds = embeds[:, 1:, :]  # [num_tiles, tile_size^2, D]
        reshaped = patch_embeds.reshape(grid_size, grid_size, tile_size, tile_size, embeds.shape[-1])

        merged = torch.cat([torch.cat([reshaped[i, j] for j in range(grid_size)], dim=1)
                            for i in range(grid_size)], dim=0)
        merged = merged.unsqueeze(0)  # [1, grid_size*tile_size, grid_size*tile_size, D]

        # Pool back to the per-tile token grid size
        pooled = torch.nn.functional.adaptive_avg_pool2d(merged.permute(0, 3, 1, 2), (tile_size, tile_size)).permute(0, 2, 3, 1)
        flattened = pooled.reshape(1, tile_size*tile_size, embeds.shape[-1])

        # Prepend the averaged class token -> original per-tile sequence shape
        with_class = torch.cat([avg_class_token, flattened], dim=1)
        out.append(with_class)

    return torch.cat(out, dim=0)


def clip_encode_image_tiled(clip_vision, image, tiles=1, ratio=1.0):
    """Encode ``image`` with ``clip_vision``, optionally mixing in a tiled encoding.

    Args:
        clip_vision: Encoder accepted by ``encode_image_``.
        image: Image batch indexed as (batch, H, W, C) (as required by ``split_tiles``).
        tiles: Grid size per side (capped at 16); 1 disables tiling.
        ratio: Weight of the whole-image embedding; tiled part gets ``1 - ratio``.

    Returns:
        For a multi-image batch, the per-image blend of full and merged tiled
        embeddings; for a single image, both concatenated along dim 0.
    """
    embeds = encode_image_(clip_vision, image)
    tiles = min(tiles, 16)

    if tiles > 1:
        # split in tiles: one group of tiles*tiles crops per input image
        image_split = split_tiles(image, tiles)
        # Encode every group and concatenate, so merge_hiddenstates sees all images
        tile_embeds = torch.cat([encode_image_(clip_vision, tile_group) for tile_group in image_split], dim=0)
        merged = merge_hiddenstates(tile_embeds, tiles)

        if embeds.shape[0] > 1:  # if we have more than one image we need to average the embeddings for consistency
            embeds = embeds * ratio + merged * (1 - ratio)
        else:  # otherwise we can concatenate them, they can be averaged later
            embeds = torch.cat([embeds * ratio, merged])

    return embeds


def encode_image_(clip_vision, image):
    """Return hidden states for ``image`` from either a comfy ClipVisionModel or a model with ``.visual``."""
    if isinstance(clip_vision, ClipVisionModel):
        out = clip_vision.encode_image(image).last_hidden_state
    else:
        pixel_values = clip_preprocess(image, size=224, crop=True).float()
        out = clip_vision.visual(pixel_values)

    return out


# Code based on https://github.com/WikiChao/FreSca (MIT License)
def fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
    """
    Apply frequency-dependent scaling to an image tensor using Fourier transforms.

    Parameters:
        x: Input tensor; the FFT is taken over the last two dims (e.g. (B, C, H, W))
        scale_low: Scaling factor for low-frequency components (default: 1.0)
        scale_high: Scaling factor for high-frequency components (default: 1.5)
        freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20).
            Values larger than half a spatial dimension cover that whole dimension.

    Returns:
        x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied,
            in the input dtype.
    """
    dtype = x.dtype
    # FFT in float32 regardless of input precision
    x = x.to(torch.float32)

    # 1) FFT over the spatial dims, low frequencies shifted to the center
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))

    # 2) Mask: scale_high everywhere, scale_low in the central window
    height, width = x_freq.shape[-2:]
    crow, ccol = height // 2, width // 2
    # Clamp at 0: a negative slice start would wrap around and shrink the window.
    row_lo, row_hi = max(crow - freq_cutoff, 0), crow + freq_cutoff
    col_lo, col_hi = max(ccol - freq_cutoff, 0), ccol + freq_cutoff
    mask = torch.full(x_freq.shape, scale_high, device=x.device, dtype=torch.float32)
    mask[..., row_lo:row_hi, col_lo:col_hi] = scale_low

    # 3) Apply frequency-specific scaling
    x_freq = x_freq * mask

    # 4) Back to the spatial domain
    x_filtered = fft.ifftn(fft.ifftshift(x_freq, dim=(-2, -1)), dim=(-2, -1)).real

    # 5) Restore original dtype
    return x_filtered.to(dtype)