Naphula committed (verified)
Commit 8eacd9a · Parent(s): 458fd98

Upload graph_v18.py

Files changed (1): graph_v18.py (+775, -0)
graph_v18.py ADDED
@@ -0,0 +1,775 @@
# graph_v18.py - Optimized for 3060 TI (8GB VRAM) and similar low-VRAM GPUs
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
"""
Module for computational graph execution.

Classes:
    Task: Abstract base class representing a computational task.
    Executor: Class for scheduling and executing directed acyclic task graphs.
"""

import os
import sys
import gc
import logging
import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from mergekit.common import get_torch_accelerator_module

# ============================================================================
# CONFIGURATION SECTION - TUNE THESE PARAMETERS FOR YOUR GPU
# ============================================================================

# --- PRIMARY VRAM TARGETS ---
# For 3060 TI (8GB): Start with 7.2-7.4GB. Increase if stable, decrease if OOM.
# For 3060 (12GB): Try 10.5-11.0GB
# For 4GB cards: Try 3.2-3.5GB
TARGET_VRAM_GB = 7.7  # Target VRAM usage in GB (TUNE THIS FIRST)

# Safety margin to account for PyTorch overhead and fragmentation
# Windows typically needs ~0.8GB, Linux ~0.5GB
VRAM_SAFETY_MARGIN_GB = 0.2  # Increase to 0.5-0.6 on Linux, 0.8-1.0 on Windows if unstable

# --- CUDA MEMORY ALLOCATOR CONFIGURATION ---
# Smaller values = less fragmentation but more overhead
# 24MB is optimal for 8GB cards, 32MB for 12GB+ cards
CUDA_MAX_SPLIT_SIZE_MB = 24  # Options: 16, 24, 32, 64

# --- CHUNK SIZE BEHAVIOR ---
# How aggressively to reduce chunk size on OOM (0.5-0.9 range)
# Lower = more conservative (slower but safer), Higher = more aggressive
CHUNK_REDUCTION_FACTOR = 0.75  # Options: 0.5 (safe), 0.7 (balanced), 0.85 (aggressive)

# Minimum chunk size before giving up and falling back to CPU
MIN_CHUNK_SIZE = 1  # Usually keep at 1, increase to 4-8 if seeing micro-chunk overhead

# Enable power-of-2 alignment for chunk sizes (following measure.py strategy)
# This improves memory allocation efficiency
ENABLE_POWER_OF_2_ALIGNMENT = True  # Set False if causing issues

# --- TASK-SPECIFIC MEMORY MULTIPLIERS ---
# These control how much extra VRAM to reserve for specific task types
# Increase if a task OOMs, decrease if underutilizing VRAM
TASK_MULTIPLIERS = {
    "ModelStock": 2.2,  # Options: 1.8-2.5 (needs room for pairwise similarities)
    "Karcher": 3.0,  # Options: 2.5-3.5 (iterative, needs working memory)
    "Consensus": 3.0,  # Options: 2.5-3.5 (similar to Karcher)
    "default": 1.2,  # Options: 1.0-1.5 (general tasks)
}

# --- MEMORY CLEANUP BEHAVIOR ---
# Enable aggressive garbage collection and cache clearing
# True = slower but more stable, False = faster but may fragment memory
ENABLE_AGGRESSIVE_CLEANUP = False  # Set True if merges are unstable

# How often to force cleanup (every N tasks). 0 = after every task
CLEANUP_FREQUENCY = 10  # Options: 0 (always), 1, 2, 5, 10

# --- FALLBACK STRATEGY ---
# Fixed chunk sizes to try if adaptive chunking fails
# Powers of 2 work best for GPU memory alignment
FALLBACK_CHUNK_SIZES = [4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2]

# --- FAST PATH OPTIMIZATION ---
# Try to execute the entire task at once before chunking
# True = faster when it works, False = always chunk (more conservative)
ENABLE_FAST_PATH = True  # Set False if getting frequent OOM on large tasks

# --- TASK ROUTING ---
# Tasks that should always run on CPU (typically I/O bound)
CPU_ONLY_TASKS = [
    "LoadTensor",
    "GatherTensors",
    "SaveTensor",
    "TensorWriterTask",
    "FinalizeModel",
    "PermutedEmbeddings",  # Gather operations don't benefit from GPU
]

# ============================================================================
# END OF CONFIGURATION SECTION
# ============================================================================

if sys.platform == "win32":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{CUDA_MAX_SPLIT_SIZE_MB}"

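# Worked example of how the settings above combine (illustrative numbers, not
# measurements): TARGET_VRAM_GB = 7.7 minus VRAM_SAFETY_MARGIN_GB = 0.2 budgets
# at most ~7.5 GB for task execution on an 8 GB card, and whatever is already
# allocated is subtracted from that budget before each chunk-size calculation.
# On Windows the assignment above expands to
# PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:24", which tells the CUDA caching
# allocator not to split blocks larger than 24 MB, limiting fragmentation.
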
ValueT = TypeVar("ValueT")
LOG = logging.getLogger(__name__)


def _round_to_power_of_2(n: int, prefer_lower: bool = True) -> int:
    """Round to the nearest power of 2 for memory alignment."""
    if n <= 0:
        return 1
    if n == 1:
        return 1

    # Find the two nearest powers of 2
    power = n.bit_length() - 1
    lower = 1 << power
    upper = 1 << (power + 1)

    if prefer_lower or (n - lower) < (upper - n):
        return lower
    return upper

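# Example behaviour (illustrative): _round_to_power_of_2(1000) floors to 512,
# since prefer_lower defaults to True, while
# _round_to_power_of_2(1000, prefer_lower=False) returns 1024, the nearer power.

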
class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        ...

    def priority(self) -> int:
        return 0

    def group_label(self) -> Optional[str]:
        return None

    def uses_accelerator(self) -> bool:
        return False

    def main_thread_only(self) -> bool:
        return False

    def duplicate_per_gpu(self) -> bool:
        return False


class TaskUniverse:
    tasks: List[Task]
    task_to_index: Dict[Task, int]
    task_arguments: Dict[int, Dict[str, int]]
    _type_id_to_index: Dict[Tuple[type, int], int]

    def __init__(self, tasks: Optional[Iterable[Task]] = None):
        self.tasks = []
        self.task_to_index = {}
        self.task_arguments = {}
        self._type_id_to_index = {}
        if tasks is not None:
            for task in tasks:
                self.add_task(task)

    def add_task(self, task: Task, recursive: bool = True) -> "TaskHandle":
        _ti_key = (type(task), id(task))
        if _ti_key in self._type_id_to_index:
            index = self._type_id_to_index[_ti_key]
            return TaskHandle(self, index)

        index = self.task_to_index.setdefault(task, len(self.tasks))
        if index < len(self.tasks):
            return TaskHandle(self, index)
        self.tasks.append(task)
        self._type_id_to_index[_ti_key] = index

        if recursive:
            self.task_arguments[index] = {}
            for k, v in task.arguments().items():
                self.task_arguments[index][k] = self.add_task(v, recursive=True)._index
        return TaskHandle(self, index)

    def get_handle(self, task: Task) -> Optional["TaskHandle"]:
        if task not in self.task_to_index:
            return None
        return TaskHandle(self, self.task_to_index[task])


class TaskHandle:
    __slots__ = ["_universe", "_index"]
    _universe: TaskUniverse
    _index: int

    def __init__(self, universe: TaskUniverse, index: int):
        self._universe = universe
        self._index = index

    def task(self) -> Task:
        return self._universe.tasks[self._index]

    def arguments(self) -> Dict[str, "TaskHandle"]:
        return {
            k: TaskHandle(self._universe, v)
            for k, v in self._universe.task_arguments[self._index].items()
        }

    def __eq__(self, other):
        if not isinstance(other, TaskHandle):
            return False
        return self._index == other._index and self._universe is other._universe

    def __hash__(self):
        return self._index

    def __str__(self):
        return f"TaskHandle({type(self.task()).__name__}, {self._index})"

    __repr__ = __str__


class ExecutionSchedule:
    tasks: List[TaskHandle]
    last_use_index: Dict[TaskHandle, int]

    def __init__(self, tasks: List[TaskHandle], last_use_index: Dict[TaskHandle, int]):
        self.tasks = tasks
        self.last_use_index = last_use_index


def build_schedule(
    targets: List[TaskHandle], cached_values: Dict[TaskHandle, Any]
) -> ExecutionSchedule:
    if not targets:
        return ExecutionSchedule(tasks=[], last_use_index={})

    universe = targets[0]._universe
    dummy_handle = TaskHandle(universe, -1)
    edge_tups: List[Tuple[TaskHandle, TaskHandle]] = []

    explored = set()
    to_explore = set(targets)
    while to_explore:
        task = to_explore.pop()
        if task in explored:
            continue
        explored.add(task)
        if task in (cached_values or {}):
            continue
        for dep in task.arguments().values():
            to_explore.add(dep)
            edge_tups.append((dep, task))

    for target in targets:
        edge_tups.append((dummy_handle, target))

    def _compare_key(node: TaskHandle) -> Tuple[str, int]:
        if node._index < 0:
            return ("", 0)
        task = node.task()
        return (task.group_label() or "", -task.priority())

    graph = networkx.DiGraph(edge_tups)
    schedule: List[TaskHandle] = [
        node
        for node in networkx.lexicographical_topological_sort(graph, key=_compare_key)
        if (node != dummy_handle) and node not in (cached_values or {})
    ]

    last_use_index = {}
    for idx, task in reversed(list(enumerate(schedule))):
        for dep in task.arguments().values():
            if dep not in last_use_index:
                last_use_index[dep] = idx
        if task not in last_use_index:
            last_use_index[task] = idx
    for task in cached_values or {}:
        if task not in last_use_index:
            last_use_index[task] = len(schedule) + 1

    return ExecutionSchedule(tasks=schedule, last_use_index=last_use_index)

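# Illustration: for the toy graph sketched at the bottom of this file (two
# example LoadConstant tasks feeding one AddTensors task; names are from that
# sketch, not from mergekit), build_schedule orders both constants before the
# sum, and last_use_index marks them as dead once AddTensors has run, so their
# tensors are evicted from the value cache immediately afterwards.

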
class Executor:
    math_device: torch.device
    storage_device: torch.device
    universe: TaskUniverse
    targets: List[TaskHandle]
    schedule: ExecutionSchedule
    cached_values: Optional[Dict[TaskHandle, Any]]
    _task_counter: int

    def __init__(
        self,
        targets: Union[List[Task], List[TaskHandle]],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
        cached_values: Optional[Dict[TaskHandle, Any]] = None,
    ):
        self.cached_values = cached_values
        self._task_counter = 0

        if isinstance(math_device, str):
            math_device = torch.device(math_device)
        if isinstance(storage_device, str):
            storage_device = torch.device(storage_device)
        self.math_device = math_device
        self.storage_device = storage_device

        if targets and isinstance(targets[0], Task):
            universe = TaskUniverse(targets)
            targets = [universe.add_task(t) for t in targets]
        elif targets and isinstance(targets[0], TaskHandle):
            universe = targets[0]._universe
        elif not targets:
            universe = TaskUniverse()
        else:
            raise ValueError("Targets must be a list of Task or TaskHandle instances")

        self.universe = universe
        self.targets = targets
        self.schedule = build_schedule(targets, cached_values=cached_values)

    def _slice_argument(self, arg: Any, start: int, end: int) -> Any:
        """Recursively slice tensors within nested structures."""
        if isinstance(arg, torch.Tensor):
            if arg.shape[0] > 1:
                return arg[start:end]
            return arg
        elif isinstance(arg, dict):
            return {k: self._slice_argument(v, start, end) for k, v in arg.items()}
        elif isinstance(arg, list):
            return [self._slice_argument(v, start, end) for v in arg]
        elif isinstance(arg, tuple):
            return tuple(self._slice_argument(v, start, end) for v in arg)
        return arg

    def _get_memory_stats(self) -> Dict[str, float]:
        """Get current VRAM statistics in GB."""
        if self.math_device.type != "cuda":
            return {}

        allocated = torch.cuda.memory_allocated(self.math_device) / (1024**3)
        reserved = torch.cuda.memory_reserved(self.math_device) / (1024**3)
        total = torch.cuda.get_device_properties(self.math_device).total_memory / (1024**3)

        return {
            "allocated_gb": allocated,
            "reserved_gb": reserved,
            "total_gb": total,
            "free_gb": total - allocated,
        }

    def _get_adaptive_chunk_size(self, task: Task, arguments: Dict[str, Any]) -> int:
        """
        Calculate optimal chunk size based on available VRAM and task requirements.

        This implements the "measure.py strategy" of targeting a specific VRAM fill level
        rather than using currently available memory, which prevents oscillation.
        """
        if self.math_device.type == "cpu":
            return 1024  # Large default for CPU

        # Get hardware capacity
        total_vram = torch.cuda.get_device_properties(self.math_device).total_memory
        target_bytes = TARGET_VRAM_GB * (1024**3)

        # Analyze tensor dimensions and count
        num_tensors = 0
        width = 0
        bytes_per_element = 4  # Default float32

        for arg in arguments.values():
            if isinstance(arg, torch.Tensor):
                num_tensors += 1
                width = max(width, arg.shape[-1] if len(arg.shape) > 1 else arg.shape[0])
                bytes_per_element = arg.element_size()
            elif isinstance(arg, dict):
                for v in arg.values():
                    if isinstance(v, torch.Tensor):
                        num_tensors += 1
                        width = max(width, v.shape[-1] if len(v.shape) > 1 else v.shape[0])
                        bytes_per_element = v.element_size()

        if num_tensors == 0 or width == 0:
            return 512  # Safe default

        # Get task-specific multiplier
        task_name = type(task).__name__
        multiplier = TASK_MULTIPLIERS.get("default", 1.2)

        for key, mult in TASK_MULTIPLIERS.items():
            if key in task_name:
                multiplier = mult
                break

        # Calculate bytes per row with multiplier for working memory
        bytes_per_row = num_tensors * width * bytes_per_element * multiplier

        # Calculate usable VRAM (target minus current allocation and safety margin)
        current_allocated = torch.cuda.memory_allocated(self.math_device)
        safety_bytes = VRAM_SAFETY_MARGIN_GB * (1024**3)
        usable_vram = max(target_bytes - current_allocated - safety_bytes, 1024 * (1024**2))

        # Calculate chunk size
        chunk_size = max(MIN_CHUNK_SIZE, int(usable_vram // bytes_per_row))

        # Apply power-of-2 alignment if enabled (measure.py strategy)
        if ENABLE_POWER_OF_2_ALIGNMENT and chunk_size > MIN_CHUNK_SIZE:
            chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)

        LOG.debug(f"Calculated chunk size: {chunk_size} (tensors={num_tensors}, width={width}, mult={multiplier:.2f})")
        return chunk_size

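    # Worked example for _get_adaptive_chunk_size (illustrative numbers, using
    # the module defaults above): a task receiving 3 fp16 tensors of width 4096
    # gives bytes_per_row = 3 * 4096 * 2 * 1.2 ~= 29,491 bytes per row, and with
    # ~1.5 GB already allocated the usable budget is roughly
    # (7.7 - 1.5 - 0.2) GiB ~= 6.44e9 bytes, so chunk_size ~= 218,000 rows,
    # which power-of-2 alignment then rounds down to 131,072.
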
    def _execute_chunked(self, task: Task, arguments: Dict[str, Any]) -> Any:
        """
        Execute task in chunks with progressive fallback strategy.

        Strategy:
            1. Try adaptive chunk size
            2. On OOM, reduce by CHUNK_REDUCTION_FACTOR
            3. Continue until success or MIN_CHUNK_SIZE reached
        """
        # Find total rows to process
        total_rows = 0
        for arg in arguments.values():
            if isinstance(arg, torch.Tensor):
                total_rows = arg.shape[0]
                break
            elif isinstance(arg, dict):
                for v in arg.values():
                    if isinstance(v, torch.Tensor):
                        total_rows = v.shape[0]
                        break
                if total_rows > 0:
                    break

        if total_rows == 0:
            return task.execute(**arguments)

        # Calculate initial chunk size
        chunk_size = self._get_adaptive_chunk_size(task, arguments)

        # FAST PATH: Try to execute all at once if chunk size >= total rows
        if ENABLE_FAST_PATH and chunk_size >= total_rows:
            try:
                gpu_args = {
                    k: self._move_tensors(v, self.math_device)
                    for k, v in arguments.items()
                }
                res = task.execute(**gpu_args)
                result = self._move_tensors(res, self.storage_device)
                del gpu_args, res
                if ENABLE_AGGRESSIVE_CLEANUP:
                    torch.cuda.empty_cache()
                return result
            except torch.OutOfMemoryError:
                LOG.warning("Fast path OOM, falling back to chunking")
                torch.cuda.empty_cache()
                gc.collect()
                chunk_size = max(MIN_CHUNK_SIZE, total_rows // 2)

        # Chunked execution with progressive reduction
        results = []
        i = 0
        oom_count = 0

        while i < total_rows:
            end = min(i + chunk_size, total_rows)

            try:
                chunk_args_gpu = {
                    k: self._move_tensors(self._slice_argument(v, i, end), self.math_device)
                    for k, v in arguments.items()
                }
                chunk_res = task.execute(**chunk_args_gpu)
                results.append(self._move_tensors(chunk_res, self.storage_device))

                del chunk_args_gpu, chunk_res

                # Aggressive cleanup per measure.py strategy
                if ENABLE_AGGRESSIVE_CLEANUP:
                    torch.cuda.empty_cache()

                i = end  # Move to next chunk
                oom_count = 0  # Reset OOM counter on success

            except torch.OutOfMemoryError:
                oom_count += 1
                torch.cuda.empty_cache()
                gc.collect()

                # Progressive reduction
                old_chunk = chunk_size
                chunk_size = max(MIN_CHUNK_SIZE, int(chunk_size * CHUNK_REDUCTION_FACTOR))

                # Apply power-of-2 alignment
                if ENABLE_POWER_OF_2_ALIGNMENT:
                    chunk_size = _round_to_power_of_2(chunk_size, prefer_lower=True)

                if chunk_size < MIN_CHUNK_SIZE:
                    LOG.error(f"Chunk size below minimum ({MIN_CHUNK_SIZE}), cannot continue")
                    raise

                LOG.warning(
                    f"OOM at chunk {old_chunk}, reducing to {chunk_size} "
                    f"(attempt {oom_count}, progress: {i}/{total_rows})"
                )

                # Safety: if we OOM too many times, something is wrong
                if oom_count > 10:
                    LOG.error("Too many OOM errors, giving up")
                    raise

        # Concatenate results
        if not results:
            return None

        if isinstance(results[0], torch.Tensor):
            return torch.cat(results, dim=0)
        elif isinstance(results[0], dict):
            out = {}
            for k in results[0].keys():
                out[k] = torch.cat([r[k] for r in results], dim=0)
            return out

        return results

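    # Example of the progressive reduction above (illustrative): a chunk of
    # 65,536 rows that OOMs is scaled by CHUNK_REDUCTION_FACTOR (0.75) to 49,152
    # and power-of-2 aligned down to 32,768; a second OOM gives 24,576 -> 16,384,
    # and so on until a chunk fits, MIN_CHUNK_SIZE is reached, or the ten-OOM
    # safety limit re-raises and the caller falls back to fixed chunk sizes.
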
    def _execute_with_fallback(self, task: Task, arguments: Dict[str, Any], accelerator) -> Any:
        """
        Execute task with comprehensive fallback strategy.

        Strategy:
            1. Try full GPU execution
            2. Try adaptive chunking
            3. Try fixed chunk sizes
            4. Fall back to CPU
        """
        task_name = type(task).__name__

        # Strategy 1: Try full GPU execution for light tasks
        try:
            gpu_args = {
                k: self._move_tensors(v, self.math_device)
                for k, v in arguments.items()
            }
            res = task.execute(**gpu_args)
            result = self._move_tensors(res, self.storage_device)
            del gpu_args, res
            return result
        except torch.OutOfMemoryError:
            LOG.debug(f"Full GPU execution failed for {task_name}, trying chunked")
            torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            LOG.warning(f"GPU execution error for {task_name}: {e}")
            torch.cuda.empty_cache()
            raise

        # Strategy 2: Try adaptive chunking
        try:
            result = self._execute_chunked(task, arguments)
            return result
        except torch.OutOfMemoryError:
            LOG.warning(f"Adaptive chunking failed for {task_name}, trying fixed sizes")
            torch.cuda.empty_cache()
            gc.collect()
        except Exception as e:
            LOG.warning(f"Chunking error for {task_name}: {e}")
            raise

        # Strategy 3: Try fixed chunk sizes
        for chunk_size in FALLBACK_CHUNK_SIZES:
            if chunk_size < MIN_CHUNK_SIZE:
                continue

            try:
                LOG.info(f"Trying fixed chunk size {chunk_size} for {task_name}")

                # Get total rows
                total_rows = 0
                for arg in arguments.values():
                    if isinstance(arg, torch.Tensor):
                        total_rows = arg.shape[0]
                        break
                    elif isinstance(arg, dict):
                        for v in arg.values():
                            if isinstance(v, torch.Tensor):
                                total_rows = v.shape[0]
                                break
                        if total_rows > 0:
                            break

                if total_rows == 0:
                    break

                results = []
                for i in range(0, total_rows, chunk_size):
                    end = min(i + chunk_size, total_rows)
                    chunk_args = {
                        k: self._slice_argument(v, i, end)
                        for k, v in arguments.items()
                    }
                    chunk_args_gpu = {
                        k: self._move_tensors(v, self.math_device)
                        for k, v in chunk_args.items()
                    }
                    chunk_res = task.execute(**chunk_args_gpu)
                    results.append(self._move_tensors(chunk_res, self.storage_device))
                    del chunk_args, chunk_args_gpu, chunk_res

                    if ENABLE_AGGRESSIVE_CLEANUP:
                        torch.cuda.empty_cache()

                if isinstance(results[0], torch.Tensor):
                    return torch.cat(results, dim=0)
                elif isinstance(results[0], dict):
                    out = {}
                    for k in results[0].keys():
                        out[k] = torch.cat([r[k] for r in results], dim=0)
                    return out
                return results

            except torch.OutOfMemoryError:
                torch.cuda.empty_cache()
                gc.collect()
                continue
            except Exception as e:
                LOG.warning(f"Fixed chunk {chunk_size} failed: {e}")
                break

        # Strategy 4: CPU fallback
        LOG.warning(f"All GPU strategies failed for {task_name}, using CPU")
        raise torch.OutOfMemoryError("Forcing CPU fallback")

    def _run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[TaskHandle, Any]]:
        last_use_index = self.schedule.last_use_index

        values: Dict[TaskHandle, Any] = {}
        if self.cached_values:
            for task, value in self.cached_values.items():
                values[task] = value

        is_gpu_execution = self.math_device.type != "cpu"
        accelerator = get_torch_accelerator_module(self.math_device.type) if is_gpu_execution else None

        for idx, task_handle in (
            pbar := tqdm.tqdm(
                list(enumerate(self.schedule.tasks)),
                disable=quiet,
                desc=desc or "Executing graph",
            )
        ):
            task = task_handle.task()
            task_type = type(task).__name__

            # Log memory stats periodically
            if is_gpu_execution and idx % 10 == 0:
                stats = self._get_memory_stats()
                LOG.debug(
                    f"Memory: {stats.get('allocated_gb', 0):.2f}GB allocated, "
                    f"{stats.get('free_gb', 0):.2f}GB free of {stats.get('total_gb', 0):.2f}GB"
                )

            # Determine execution strategy
            is_cpu_only_task = task_type in CPU_ONLY_TASKS
            want_gpu = is_gpu_execution and task.uses_accelerator() and not is_cpu_only_task

            # Collect arguments
            arguments = {k: values[h] for k, h in task_handle.arguments().items()}

            success = False

            # Try GPU execution
            if want_gpu:
                try:
                    res = self._execute_with_fallback(task, arguments, accelerator)
                    values[task_handle] = res
                    success = True
                except torch.OutOfMemoryError:
                    LOG.warning(f"All GPU strategies exhausted for {task_type}, falling back to CPU")
                    success = False
                except Exception as e:
                    LOG.error(f"GPU execution failed for {task_type}: {e}")
                    success = False

                # Cleanup after GPU attempt
                if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()

            # CPU fallback
            if not success:
                if want_gpu:
                    LOG.info(f"Executing {task_type} on CPU")

                # Ensure cleanup before CPU execution
                if is_gpu_execution:
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()

                # Move arguments to CPU
                cpu_arguments = {
                    k: self._move_tensors(v, torch.device("cpu"))
                    for k, v in arguments.items()
                }

                res = task.execute(**cpu_arguments)
                del cpu_arguments
                res = self._move_tensors(res, self.storage_device)
                values[task_handle] = res

            del res
            del arguments

            if task_handle in self.targets:
                yield (task_handle, values[task_handle])

            # Evict unreferenced values
            expired = []
            for key in values:
                if idx >= last_use_index[key]:
                    expired.append(key)
            for key in expired:
                del values[key]

            # Periodic cleanup (measure.py strategy)
            self._task_counter += 1
            if is_gpu_execution and ENABLE_AGGRESSIVE_CLEANUP:
                if CLEANUP_FREQUENCY == 0 or self._task_counter % max(1, CLEANUP_FREQUENCY) == 0:
                    gc.collect()
                    if accelerator:
                        accelerator.empty_cache()

        del values
        del pbar

    def run(
        self,
        quiet: bool = False,
        desc: Optional[str] = None,
    ) -> Iterator[Tuple[Task, Any]]:
        for handle, value in self._run(quiet=quiet, desc=desc):
            yield (handle.task(), value)

    def execute(self, desc: Optional[str] = None) -> None:
        for _ in self.run(desc=desc):
            pass

    def _move_tensors(
        self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
    ) -> Any:
        """Move tensors to the specified device, handling nested structures."""
        if non_blocking is None:
            non_blocking = device.type in ["cuda", "xpu"]

        if isinstance(value, torch.Tensor):
            if value.device == device:
                return value
            return value.to(device=device, non_blocking=non_blocking)
        elif isinstance(value, dict):
            return {
                k: self._move_tensors(v, device, non_blocking)
                for k, v in value.items()
            }
        elif isinstance(value, list):
            return [self._move_tensors(v, device, non_blocking) for v in value]
        elif isinstance(value, tuple):
            return tuple(self._move_tensors(v, device, non_blocking) for v in value)

        return value
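

# ----------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity; LoadConstant and AddTensors are
# toy examples, not part of the module's or mergekit's API). It defines two
# trivial Task subclasses and runs them through the Executor on CPU.
# Run with `python graph_v18.py` to try it.
# ----------------------------------------------------------------------------
if __name__ == "__main__":

    class LoadConstant(Task):
        """Toy task with no dependencies: produces a constant 4x4 tensor."""

        value: float

        def arguments(self) -> Dict[str, Task]:
            return {}

        def execute(self) -> torch.Tensor:
            return torch.full((4, 4), self.value)

    class AddTensors(Task):
        """Toy task that depends on two other tasks and sums their outputs."""

        a: Task
        b: Task

        def arguments(self) -> Dict[str, Task]:
            return {"a": self.a, "b": self.b}

        def execute(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a + b

    total = AddTensors(a=LoadConstant(value=1.0), b=LoadConstant(value=2.0))
    # The Executor builds the dependency graph, schedules both constants before
    # the sum, and yields results only for the requested targets.
    for finished_task, value in Executor([total]).run(quiet=True):
        print(type(finished_task).__name__, value.mean().item())  # AddTensors 3.0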