Upload config.py
Browse files
config.py
CHANGED
|
@@ -370,18 +370,18 @@ class DiffusionPipelineConfig:
|
|
| 370 |
print(">>> PIPELINE TYPE:", type(pipeline))
|
| 371 |
|
| 372 |
# Try to move each component using .to_empty()
|
| 373 |
-
for name in ["unet", "transformer", "vae", "text_encoder"]:
|
| 374 |
-
module = getattr(pipeline, name, None)
|
| 375 |
-
if isinstance(module, torch.nn.Module):
|
| 376 |
-
try:
|
| 377 |
-
print(f">>> Moving {name} to {device} using to_empty()")
|
| 378 |
-
module.to_empty(device)
|
| 379 |
-
except Exception as e:
|
| 380 |
-
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
| 381 |
-
try:
|
| 382 |
-
print(f">>> Falling back to {name}.to({device})")
|
| 383 |
-
module.to(device)
|
| 384 |
-
except Exception as ee:
|
| 385 |
print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
|
| 386 |
|
| 387 |
# Identify main model (for patching)
|
|
|
|
| 370 |
print(">>> PIPELINE TYPE:", type(pipeline))
|
| 371 |
|
| 372 |
# Try to move each component using .to_empty()
|
| 373 |
+
for name in ["unet", "transformer", "vae", "text_encoder"]:
|
| 374 |
+
module = getattr(pipeline, name, None)
|
| 375 |
+
if isinstance(module, torch.nn.Module):
|
| 376 |
+
try:
|
| 377 |
+
print(f">>> Moving {name} to {device} using to_empty()")
|
| 378 |
+
module.to_empty(device=device) # Use keyword argument
|
| 379 |
+
except Exception as e:
|
| 380 |
+
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
| 381 |
+
try:
|
| 382 |
+
print(f">>> Falling back to {name}.to({device})")
|
| 383 |
+
module.to(device)
|
| 384 |
+
except Exception as ee:
|
| 385 |
print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
|
| 386 |
|
| 387 |
# Identify main model (for patching)
|