# Copyright (C) 2025 Arcee AI # SPDX-License-Identifier: LGPL-3.0-only """ Module for computational graph execution. Classes: Task: Abstract base class representing a computational task. Executor: Class for scheduling and executing directed acyclic task graphs. """ import os import sys import gc import logging import networkx import torch import tqdm from pydantic import BaseModel from typing_extensions import Generic, TypeVar from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union from mergekit.common import get_torch_accelerator_module # Windows/NVIDIA specific allocator tuning to reduce fragmentation if sys.platform == "win32": os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32" ValueT = TypeVar("ValueT") LOG = logging.getLogger(__name__) class Task(ABC, BaseModel, Generic[ValueT], frozen=True): @abstractmethod def arguments(self) -> Dict[str, "Task"]: ... @abstractmethod def execute(self, **kwargs) -> ValueT: ... def priority(self) -> int: return 0 def group_label(self) -> Optional[str]: return None def uses_accelerator(self) -> bool: return False def main_thread_only(self) -> bool: return False def duplicate_per_gpu(self) -> bool: return False class TaskUniverse: tasks: List[Task] task_to_index: Dict[Task, int] task_arguments: Dict[int, Dict[str, int]] _type_id_to_index: Dict[Tuple[type, int], int] def __init__(self, tasks: Optional[Iterable[Task]] = None): self.tasks = [] self.task_to_index = {} self.task_arguments = {} self._type_id_to_index = {} if tasks is not None: for task in tasks: self.add_task(task) def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle": _ti_key = (type(task), id(task)) if _ti_key in self._type_id_to_index: index = self._type_id_to_index[_ti_key] return TaskHandle(self, index) index = self.task_to_index.setdefault(task, len(self.tasks)) if index < len(self.tasks): return TaskHandle(self, index) self.tasks.append(task) self._type_id_to_index[_ti_key] = index if recursive: self.task_arguments[index] = {} for k, v in task.arguments().items(): self.task_arguments[index][k] = self.add_task(v, recursive=True)._index return TaskHandle(self, index) def get_handle(self, task: Task) -> Optional["TaskHandle"]: if task not in self.task_to_index: return None return TaskHandle(self, self.task_to_index[task]) class TaskHandle: __slots__ = ["_universe", "_index"] _universe: TaskUniverse _index: int def __init__(self, universe: TaskUniverse, index: int): self._universe = universe self._index = index def task(self) -> Task: return self._universe.tasks[self._index] def arguments(self) -> Dict[str, "TaskHandle"]: return { k: TaskHandle(self._universe, v) for k, v in self._universe.task_arguments[self._index].items() } def __eq__(self, other): if not isinstance(other, TaskHandle): return False return self._index == other._index and self._universe is other._universe def __hash__(self): return self._index def __str__(self): return f"TaskHandle({type(self.task()).__name__}, {self._index})" __repr__ = __str__ class ExecutionSchedule: tasks: List[TaskHandle] last_use_index: Dict[TaskHandle, int] def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]): self.tasks = tasks self.last_use_index = last_use_index def build_schedule( targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any] ) -> ExecutionSchedule: if not targets: return ExecutionSchedule(tasks=[], last_use_index={}) universe = targets[0]._universe dummy_handle = TaskHandle(universe, -1) edge_tups: List[Tuple[TaskHandle, TaskHandle]] = [] explored = set() to_explore = set(targets) while to_explore: task = to_explore.pop() if task in explored: continue explored.add(task) if task in (cached_values or {}): continue for dep in task.arguments().values(): to_explore.add(dep) edge_tups.append((dep, task)) for target in targets: edge_tups.append((dummy_handle, target)) def _compare_key(node: TaskHandle) -> Tuple[str, int]: if node._index < 0: return ("", 0) task = node.task() return (task.group_label() or "", -task.priority()) graph = networkx.DiGraph(edge_tups) schedule: List[TaskHandle] = [ node for node in networkx.lexicographical_topological_sort(graph, key=_compare_key) if (node != dummy_handle) and node not in (cached_values or {}) ] last_use_index = {} for idx, task in reversed(list(enumerate(schedule))): for dep in task.arguments().values(): if dep not in last_use_index: last_use_index[dep] = idx if task not in last_use_index: last_use_index[task] = idx for task in cached_values or {}: if task not in last_use_index: last_use_index[task] = len(schedule) + 1 return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index) class Executor: math_device: torch.device storage_device: torch.device universe: TaskUniverse targets: List[TaskHandle] schedule: ExecutionSchedule cached_values: Optional[Dict[TaskHandle, Any]] def __init__( self, targets: Union[List[Task], List[TaskHandle]], math_device: torch.device = torch.device("cpu"), storage_device: torch.device = torch.device("cpu"), cached_values: Optional[Dict[TaskHandle, Any]] = None, ): self.cached_values = cached_values if isinstance(math_device, str): math_device = torch.device(math_device) if isinstance(storage_device, str): storage_device = torch.device(storage_device) self.math_device = math_device self.storage_device = storage_device if targets and isinstance(targets[0], Task): universe = TaskUniverse(targets) targets = [universe.add_task(t) for t in targets] elif targets and isinstance(targets[0], TaskHandle): universe = targets[0]._universe elif not targets: universe = TaskUniverse() else: raise ValueError("Targets must be a list of Task or TaskHandle instances") self.universe = universe self.targets = targets self.schedule = build_schedule(targets, cached_values=cached_values) def _slice_argument(self, arg: Any, start: int, end: int) -> Any: """Helper to slice tensors within nested structures.""" if isinstance(arg, torch.Tensor): # Only slice if the dimension is large enough if arg.shape[0] > 1: return arg[start:end] return arg elif isinstance(arg, dict): return {k: self._slice_argument(v, start, end) for k, v in arg.items()} elif isinstance(arg, list): return [self._slice_argument(v, start, end) for v in arg] elif isinstance(arg, tuple): return tuple(self._slice_argument(v, start, end) for v in arg) return arg def _execute_chunked(self, task: Task, arguments: Dict[str, Any], chunk_size: int) -> Any: """ Executes a task by splitting input tensors into chunks, processing on GPU, and concatenating results on CPU. """ # Find a reference tensor to determine batch size ref_tensor = None for arg in arguments.values(): if isinstance(arg, torch.Tensor): ref_tensor = arg break elif isinstance(arg, dict): for v in arg.values(): if isinstance(v, torch.Tensor): ref_tensor = v break if ref_tensor is not None: break if ref_tensor is None: raise ValueError("No tensors found to chunk") total_rows = ref_tensor.shape[0] results = [] accelerator = get_torch_accelerator_module(self.math_device.type) if self.math_device.type != "cpu" else None # Process in chunks for i in range(0, total_rows, chunk_size): end = min(i + chunk_size, total_rows) # Slice inputs chunk_args = { k: self._slice_argument(v, i, end) for k, v in arguments.items() } # Move chunk inputs to GPU chunk_args_gpu = { k: self._move_tensors(v, self.math_device) for k, v in chunk_args.items() } # Execute chunk_res = task.execute(**chunk_args_gpu) # Move result to CPU immediately chunk_res_cpu = self._move_tensors(chunk_res, self.storage_device) results.append(chunk_res_cpu) # Cleanup del chunk_args del chunk_args_gpu del chunk_res # Clear cache inside loop to handle complex methods like Magic if accelerator: accelerator.empty_cache() # Concatenate results if isinstance(results[0], torch.Tensor): return torch.cat(results, dim=0) elif isinstance(results[0], dict): # Reassemble dict of tensors out = {} for k in results[0].keys(): out[k] = torch.cat([r[k] for r in results], dim=0) return out else: raise ValueError("Unsupported return type for chunking") def _run( self, quiet: bool = False, desc: Optional[str] = None, ) -> Iterator[Tuple[TaskHandle, Any]]: last_use_index = self.schedule.last_use_index values: Dict[TaskHandle, Any] = {} if self.cached_values: for task, value in self.cached_values.items(): values[task] = value is_gpu_execution = self.math_device.type != "cpu" accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None for idx, task_handle in ( pbar := tqdm.tqdm( list(enumerate(self.schedule.tasks)), disable=quiet, desc=desc or "Executing graph", ) ): task = task_handle.task() task_type = type(task).__name__ # Heuristic: Don't force I/O tasks to GPU # PermutedEmbeddings is essentially a gather operation, hard to chunk, better on CPU if memory is tight is_io_task = task_type in ["LoadTensor", "GatherTensors", "SaveTensor", "TensorWriterTask", "FinalizeModel", "PermutedEmbeddings"] want_gpu = is_gpu_execution and (task.uses_accelerator() or not is_io_task) success = False if want_gpu: try: # 1. Try Full GPU Execution arguments = {} for name, dep_handle in task_handle.arguments().items(): value = values[dep_handle] value = self._move_tensors(value, self.math_device) arguments[name] = value res = task.execute(**arguments) del arguments res = self._move_tensors(res, self.storage_device) values[task_handle] = res success = True except torch.OutOfMemoryError: # Cleanup arguments = None res = None gc.collect() if accelerator: accelerator.empty_cache() # 2. Try Chunked GPU Execution with Adaptive Sizing chunk_sizes = [4096, 2048, 1024, 512, 256, 128, 64] # Reload arguments on CPU arguments = {} for name, dep_handle in task_handle.arguments().items(): arguments[name] = values[dep_handle] # Already on storage device for chunk_size in chunk_sizes: try: LOG.info(f"OOM on {task_type}. Attempting chunked GPU execution (size={chunk_size})...") res = self._execute_chunked(task, arguments, chunk_size=chunk_size) values[task_handle] = res success = True LOG.info(f"Chunked execution successful for {task_type} (size={chunk_size})") break except Exception as e: LOG.warning(f"Chunked execution failed at size {chunk_size} ({str(e)}).") gc.collect() if accelerator: accelerator.empty_cache() # If it wasn't an OOM (e.g. index error), stop trying chunking if not isinstance(e, torch.OutOfMemoryError): break # 3. CPU Fallback if not success: if want_gpu: LOG.warning(f"All GPU attempts failed for {task_type}. Falling back to CPU.") # Ensure we clean up any GPU debris before CPU attempt if is_gpu_execution: gc.collect() if accelerator: accelerator.empty_cache() arguments = {} for name, dep_handle in task_handle.arguments().items(): value = values[dep_handle] value = self._move_tensors(value, torch.device("cpu")) arguments[name] = value res = task.execute(**arguments) del arguments res = self._move_tensors(res, self.storage_device) values[task_handle] = res del res if task_handle in self.targets: yield (task_handle, values[task_handle]) # Evict unreferenced values expired = [] for key in values: if idx >= last_use_index[key]: expired.append(key) for key in expired: del values[key] # Aggressive cleanup if is_gpu_execution: gc.collect() if accelerator: accelerator.empty_cache() del values del pbar def run( self, quiet: bool = False, desc: Optional[str] = None, ) -> Iterator[Tuple[Task, Any]]: for handle, value in self._run(quiet=quiet, desc=desc): yield (handle.task(), value) def execute(self, desc: Optional[str] = None) -> None: for _ in self.run(desc=desc): pass def _move_tensors( self, value: Any, device: torch.device, non_blocking: Optional[bool] = None ) -> Any: if non_blocking is None: non_blocking = device.type in ["cuda", "xpu"] if isinstance(value, torch.Tensor): if value.device == device: return value return value.to(device=device, non_blocking=non_blocking) elif isinstance(value, dict): return {k: self._move_tensors(v, device, non_blocking) for k, v in value.items()} elif isinstance(value, list): return [self._move_tensors(v, device, non_blocking) for v in value] elif isinstance(value, tuple): return tuple(self._move_tensors(v, device, non_blocking) for v in value) return value