selfit-camera committed on
Commit 164deee · 1 Parent(s): 4fa21de
Files changed (6)
  1. __lib__/app.py +0 -0
  2. __lib__/i18n/en.pyc +0 -0
  3. __lib__/nfsw.pyc +0 -0
  4. __lib__/util.pyc +0 -0
  5. pipeline.py +206 -32
  6. scheduling_omni.py +634 -0
__lib__/app.py CHANGED
The diff for this file is too large to render. See raw diff
 
__lib__/i18n/en.pyc CHANGED
Binary files a/__lib__/i18n/en.pyc and b/__lib__/i18n/en.pyc differ
 
__lib__/nfsw.pyc CHANGED
Binary files a/__lib__/nfsw.pyc and b/__lib__/nfsw.pyc differ
 
__lib__/util.pyc CHANGED
Binary files a/__lib__/util.pyc and b/__lib__/util.pyc differ
 
pipeline.py CHANGED
@@ -17,6 +17,7 @@ from diffusers import DiffusionPipeline, DDIMScheduler
17
  from diffusers.configuration_utils import ConfigMixin, register_to_config
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from diffusers.utils import BaseOutput
 
20
 
21
  # Optimization imports
22
  try:
@@ -729,6 +730,86 @@ class OmniMMDitV2(ModelMixin, PreTrainedModel):
729
 
730
  return BaseOutput(sample=output)
731
 
732
  # -----------------------------------------------------------------------------
733
  # 5. The "Fancy" Pipeline
734
  # -----------------------------------------------------------------------------
@@ -744,7 +825,7 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
744
  tokenizer: CLIPTokenizer
745
  text_encoder: CLIPTextModel
746
  vae: Any # AutoencoderKL
747
- scheduler: DDIMScheduler
748
 
749
  _optional_components = ["visual_encoder"]
750
 
@@ -754,7 +835,7 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
754
  vae: Any,
755
  text_encoder: CLIPTextModel,
756
  tokenizer: CLIPTokenizer,
757
- scheduler: DDIMScheduler,
758
  visual_encoder: Optional[Any] = None,
759
  ):
760
  super().__init__()
@@ -792,6 +873,12 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
792
 
793
  self._is_compiled = False
794
  self._is_fp8_enabled = False
795
 
796
  def enable_fp8_quantization(self):
797
  """Enable FP8 quantization for faster inference"""
@@ -860,6 +947,29 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
860
 
861
  return self
862
 
863
  @torch.no_grad()
864
  def __call__(
865
  self,
@@ -880,6 +990,7 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
880
  callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
881
  callback_steps: int = 1,
882
  use_optimized_inference: bool = True,
 
883
  **kwargs,
884
  ):
885
  # Use optimized inference context
@@ -905,6 +1016,7 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
905
  return_dict=return_dict,
906
  callback=callback,
907
  callback_steps=callback_steps,
 
908
  **kwargs,
909
  )
910
 
@@ -926,6 +1038,7 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
926
  return_dict: bool = True,
927
  callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
928
  callback_steps: int = 1,
 
929
  **kwargs,
930
  ):
931
  # Validate and set default dimensions
@@ -977,8 +1090,15 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
977
  visual_embeddings_list.append(vis_emb)
978
 
979
  # Prepare timesteps
980
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
981
- timesteps = self.scheduler.timesteps
982
 
983
  # Initialize latent space
984
  num_channels_latents = self.model.config.in_channels
@@ -989,34 +1109,88 @@ class OmniMMDitV2Pipeline(DiffusionPipeline):
989
  latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
990
  latents = latents * self.scheduler.init_noise_sigma
991
 
992
- # Denoising loop with optimizations
993
- with self.progress_bar(total=num_inference_steps) as progress_bar:
994
- for i, t in enumerate(timesteps):
995
- latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
996
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
997
-
998
- # Use mixed precision autocast
999
- with self.model_optimizer.autocast_context():
1000
- noise_pred = self.model(
1001
- hidden_states=latent_model_input,
1002
- timestep=t,
1003
- encoder_hidden_states=torch.cat([text_embeddings] * 2),
1004
- visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
1005
- video_frames=num_frames
1006
- ).sample
1007
-
1008
- # Apply classifier-free guidance
1009
- if guidance_scale > 1.0:
1010
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1011
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1012
-
1013
- latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
1014
-
1015
- # Call callback if provided
1016
- if callback is not None and i % callback_steps == 0:
1017
- callback(i, t, latents)
1018
-
1019
- progress_bar.update()
1020
 
1021
  # Decode latents with proper post-processing
1022
  if output_type == "latent":
 
17
  from diffusers.configuration_utils import ConfigMixin, register_to_config
18
  from diffusers.models.modeling_utils import ModelMixin
19
  from diffusers.utils import BaseOutput
20
+ from .scheduling_omni import OmniScheduler
21
 
22
  # Optimization imports
23
  try:
 
730
 
731
  return BaseOutput(sample=output)
732
 
733
+ # -----------------------------------------------------------------------------
734
+ # 4.5 π-Flow Policy Network (coarse trajectory predictor)
735
+ # -----------------------------------------------------------------------------
736
+
737
+
738
+ class PiFlowPolicyNetwork(nn.Module):
739
+ """
740
+ Lightweight π-Flow policy network: predicts multi-step velocity trajectories in one forward pass for few-step sampling.
741
+ Uses only globally aggregated text/visual features plus time embeddings, and outputs velocity fields matching the latent shape.
742
+ """
743
+
744
+ def __init__(
745
+ self,
746
+ text_hidden_size: int,
747
+ visual_embed_dim: int,
748
+ latent_channels: int,
749
+ hidden_size: int = 1024,
750
+ ):
751
+ super().__init__()
752
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
753
+ self.vis_proj = nn.Linear(visual_embed_dim, hidden_size)
754
+ self.time_proj = nn.Sequential(
755
+ nn.Linear(1, hidden_size),
756
+ nn.SiLU(),
757
+ nn.Linear(hidden_size, hidden_size),
758
+ )
759
+ self.fuse = nn.Sequential(
760
+ nn.Linear(hidden_size * 3, hidden_size),
761
+ nn.SiLU(),
762
+ nn.Linear(hidden_size, latent_channels),
763
+ )
764
+ self.latent_channels = latent_channels
765
+
766
+ def forward(
767
+ self,
768
+ text_embeddings: torch.Tensor,
769
+ visual_embeddings_list: Optional[List[torch.Tensor]],
770
+ timesteps: torch.Tensor,
771
+ latent_shape: torch.Size,
772
+ ) -> torch.Tensor:
773
+ """
774
+ Args:
775
+ text_embeddings: [B, L, D_txt]
776
+ visual_embeddings_list: list of [B, L_vis, D_vis] or None
777
+ timesteps: [S] step values in [0,1]
778
+ latent_shape: target latent shape (B, C, ...)
779
+ Returns:
780
+ policy_velocities: [S, *latent_shape]
781
+ """
782
+ device = text_embeddings.device
783
+ dtype = text_embeddings.dtype
784
+ batch_size = text_embeddings.shape[0]
785
+
786
+ text_ctx = text_embeddings.mean(dim=1)
787
+
788
+ if visual_embeddings_list:
789
+ vis_tokens = [v.mean(dim=1) for v in visual_embeddings_list]
790
+ vis_ctx = torch.stack(vis_tokens, dim=0).mean(dim=0)
791
+ else:
792
+ vis_ctx = torch.zeros_like(text_ctx)
793
+
794
+ txt_feat = self.text_proj(text_ctx)
795
+ vis_feat = self.vis_proj(vis_ctx.to(device=device, dtype=dtype))
796
+
797
+ time_feat = self.time_proj(timesteps.unsqueeze(-1).to(device=device, dtype=dtype))
798
+
799
+ velocities = []
800
+ for t_feat in time_feat:
801
+ fused = torch.cat([txt_feat, vis_feat, t_feat.expand_as(txt_feat)], dim=-1)
802
+ step_token = self.fuse(fused).tanh() # [B, C]
803
+
804
+ step_field = step_token
805
+ while len(step_field.shape) < len(latent_shape):
806
+ step_field = step_field.unsqueeze(-1)
807
+ step_field = step_field.expand(batch_size, *latent_shape[1:])
808
+ velocities.append(step_field)
809
+
810
+ policy_velocities = torch.stack(velocities, dim=0)
811
+ return policy_velocities
812
+
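
Note (not part of the diff): a minimal shape-check sketch for the PiFlowPolicyNetwork defined above, using illustrative dimensions rather than values from the real text encoder / DiT config.

    import torch

    policy = PiFlowPolicyNetwork(
        text_hidden_size=768, visual_embed_dim=1024, latent_channels=16, hidden_size=512
    )
    text_emb = torch.randn(2, 77, 768)             # [B, L, D_txt]
    vis_emb = [torch.randn(2, 257, 1024)]          # list of [B, L_vis, D_vis]
    t_grid = torch.linspace(0, 1, 4)               # S = 4 coarse steps in [0, 1]
    latent_shape = torch.Size([2, 16, 8, 32, 32])  # [B, C, F, H, W]

    velocities = policy(text_emb, vis_emb, t_grid, latent_shape)
    print(velocities.shape)  # torch.Size([4, 2, 16, 8, 32, 32]), one velocity field per step
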
813
  # -----------------------------------------------------------------------------
814
  # 5. The "Fancy" Pipeline
815
  # -----------------------------------------------------------------------------
 
825
  tokenizer: CLIPTokenizer
826
  text_encoder: CLIPTextModel
827
  vae: Any # AutoencoderKL
828
+ scheduler: Union[DDIMScheduler, OmniScheduler]
829
 
830
  _optional_components = ["visual_encoder"]
831
 
 
835
  vae: Any,
836
  text_encoder: CLIPTextModel,
837
  tokenizer: CLIPTokenizer,
838
+ scheduler: Union[DDIMScheduler, OmniScheduler],
839
  visual_encoder: Optional[Any] = None,
840
  ):
841
  super().__init__()
 
873
 
874
  self._is_compiled = False
875
  self._is_fp8_enabled = False
876
+ self.policy_network = PiFlowPolicyNetwork(
877
+ text_hidden_size=self.text_encoder.config.hidden_size,
878
+ visual_embed_dim=self.model.config.visual_embed_dim,
879
+ latent_channels=self.model.config.in_channels,
880
+ hidden_size=min(1024, self.model.config.hidden_size),
881
+ )
882
 
883
  def enable_fp8_quantization(self):
884
  """Enable FP8 quantization for faster inference"""
 
947
 
948
  return self
949
 
950
+ def _predict_policy_trajectory(
+ self,
+ text_embeddings: torch.Tensor,
+ visual_embeddings_list: Optional[List[torch.Tensor]],
+ latents: torch.Tensor,
+ total_steps: int,
+ ) -> Optional[torch.Tensor]:
+ """
+ Predict the coarse-stage velocity trajectory in one shot using the π-Flow policy network.
+ """
+ if self.policy_network is None or total_steps <= 0:
+ return None
+ # Keep the policy network on the same device/dtype as the inputs
+ self.policy_network = self.policy_network.to(device=latents.device, dtype=text_embeddings.dtype)
+ time_grid = torch.linspace(0, 1, total_steps, device=latents.device, dtype=text_embeddings.dtype)
+ return self.policy_network(
+ text_embeddings=text_embeddings.detach(),
+ visual_embeddings_list=visual_embeddings_list,
+ timesteps=time_grid,
+ latent_shape=latents.shape,
+ )
972
+
973
  @torch.no_grad()
974
  def __call__(
975
  self,
 
990
  callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
991
  callback_steps: int = 1,
992
  use_optimized_inference: bool = True,
993
+ use_pi_flow_policy: bool = False,
994
  **kwargs,
995
  ):
996
  # Use optimized inference context
 
1016
  return_dict=return_dict,
1017
  callback=callback,
1018
  callback_steps=callback_steps,
1019
+ use_pi_flow_policy=use_pi_flow_policy,
1020
  **kwargs,
1021
  )
1022
 
 
1038
  return_dict: bool = True,
1039
  callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
1040
  callback_steps: int = 1,
1041
+ use_pi_flow_policy: bool = False,
1042
  **kwargs,
1043
  ):
1044
  # Validate and set default dimensions
 
1090
  visual_embeddings_list.append(vis_emb)
1091
 
1092
  # Prepare timesteps
1093
+ if isinstance(self.scheduler, OmniScheduler):
1094
+ # π-Flow / Flow Matching path
1095
+ self.scheduler.register_to_config(prediction_type="velocity")  # config is frozen; update it via register_to_config
1096
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device, use_karras_sigmas=True)
1097
+ total_steps = len(self.scheduler.timesteps) - 1
1098
+ else:
1099
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
1100
+ timesteps = self.scheduler.timesteps
1101
+ total_steps = len(timesteps)
1102
 
1103
  # Initialize latent space
1104
  num_channels_latents = self.model.config.in_channels
 
1109
  latents = torch.randn(shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
1110
  latents = latents * self.scheduler.init_noise_sigma
1111
 
1112
+ if isinstance(self.scheduler, OmniScheduler):
1113
+ policy_velocities = None
1114
+ if use_pi_flow_policy:
1115
+ policy_velocities = self._predict_policy_trajectory(
1116
+ text_embeddings=text_embeddings,
1117
+ visual_embeddings_list=visual_embeddings_list,
1118
+ latents=latents,
1119
+ total_steps=total_steps,
1120
+ )
1121
+
1122
+ with self.progress_bar(total=total_steps) as progress_bar:
1123
+ for step_idx in range(total_steps):
1124
+ t_val = self.scheduler.timesteps[step_idx]
1125
+
1126
+ use_policy_step = (
1127
+ use_pi_flow_policy and policy_velocities is not None and step_idx < self.scheduler.coarse_steps
1128
+ )
1129
+
1130
+ if use_policy_step:
1131
+ model_output = policy_velocities[step_idx]
1132
+ model_fn = None
1133
+ else:
1134
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
1135
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_val)
1136
+
1137
+ with self.model_optimizer.autocast_context():
1138
+ noise_pred = self.model(
1139
+ hidden_states=latent_model_input,
1140
+ timestep=t_val,
1141
+ encoder_hidden_states=torch.cat([text_embeddings] * 2),
1142
+ visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
1143
+ video_frames=num_frames
1144
+ ).sample
1145
+
1146
+ if guidance_scale > 1.0:
1147
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1148
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1149
+
1150
+ model_output = noise_pred
1151
+ model_fn = None # extendable to second eval for higher-order solvers
1152
+
1153
+ step_output = self.scheduler.step(
1154
+ model_output=model_output,
1155
+ timestep=step_idx,
1156
+ sample=latents,
1157
+ model_fn=model_fn,
1158
+ )
1159
+ latents = step_output.prev_sample if hasattr(step_output, "prev_sample") else step_output[0]
1160
+
1161
+ if callback is not None and step_idx % callback_steps == 0:
1162
+ callback(step_idx, t_val, latents)
1163
+
1164
+ progress_bar.update()
1165
+ else:
1166
+ # Compatible with original DDIM/standard scheduler
1167
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1168
+ for i, t in enumerate(timesteps):
1169
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
1170
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1171
+
1172
+ # Use mixed precision autocast
1173
+ with self.model_optimizer.autocast_context():
1174
+ noise_pred = self.model(
1175
+ hidden_states=latent_model_input,
1176
+ timestep=t,
1177
+ encoder_hidden_states=torch.cat([text_embeddings] * 2),
1178
+ visual_conditions=visual_embeddings_list * 2 if visual_embeddings_list else None,
1179
+ video_frames=num_frames
1180
+ ).sample
1181
+
1182
+ # Apply classifier-free guidance
1183
+ if guidance_scale > 1.0:
1184
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1185
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1186
+
1187
+ latents = self.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample
1188
+
1189
+ # Call callback if provided
1190
+ if callback is not None and i % callback_steps == 0:
1191
+ callback(i, t, latents)
1192
+
1193
+ progress_bar.update()
1194
 
1195
  # Decode latents with proper post-processing
1196
  if output_type == "latent":
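
Note (usage sketch, not code from this commit): with the changes above, the new sampling path would be exercised roughly as follows. The pipeline loading and the prompt argument name are assumptions based on common diffusers conventions; the OmniScheduler arguments and the use_pi_flow_policy flag are the ones introduced in this diff.

    from scheduling_omni import OmniScheduler  # the diff imports it relatively as .scheduling_omni

    # Assumes `pipe` is an already-constructed OmniMMDitV2Pipeline.
    pipe.scheduler = OmniScheduler(
        num_inference_steps=4,   # few-step sampling
        solver_order=2,          # Heun
        use_multi_stage=True,
        coarse_ratio=0.7,
    )

    result = pipe(
        prompt="a corgi surfing at sunset",  # argument name assumed; not visible in this diff
        num_inference_steps=4,
        guidance_scale=5.0,
        use_pi_flow_policy=True,  # route the coarse steps through PiFlowPolicyNetwork
    )
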
scheduling_omni.py ADDED
@@ -0,0 +1,634 @@
1
+ # Copyright 2025 OmniEdit Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ OmniScheduler: Unified Hybrid Flow-Diffusion Sampler
17
+
18
+ Key features:
19
+ 1. Policy-driven velocity field inspired by π-Flow for few-step generation.
20
+ 2. Multi-stage sampling (coarse → refine) to balance speed and detail.
21
+ 3. High-order ODE solvers (RK2 / RK4) for faster convergence.
22
+ 4. Hybrid flow-matching + diffusion sampling for stable trajectories.
23
+ """
24
+
25
+ import math
26
+ from dataclasses import dataclass
27
+ from typing import Optional, Tuple, Union, List
28
+
29
+ import torch
30
+
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.utils import BaseOutput
33
+ from diffusers.utils.torch_utils import randn_tensor
34
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
35
+
36
+
37
+ @dataclass
38
+ class OmniSchedulerOutput(BaseOutput):
39
+ """
40
+ Output container for OmniScheduler.
41
+
42
+ Args:
43
+ prev_sample (`torch.Tensor`):
44
+ Sample at the previous timestep (x_{t-1}) for the next denoising step.
45
+ prev_sample_mean (`torch.Tensor`):
46
+ Mean estimate of the sample for inspection/debugging.
47
+ velocity (`torch.Tensor`, *optional*):
48
+ Policy-predicted velocity field for flow matching.
49
+ stage (`int`):
50
+ Current stage (0: coarse, 1: refine).
51
+ """
52
+
53
+ prev_sample: torch.Tensor
54
+ prev_sample_mean: torch.Tensor
55
+ velocity: Optional[torch.Tensor] = None
56
+ stage: int = 0
57
+
58
+
59
+ class OmniScheduler(SchedulerMixin, ConfigMixin):
60
+ """
61
+ `OmniScheduler` - Unified Hybrid Flow-Diffusion Sampler
62
+
63
+ Combines flow matching and high-order ODE solvers to achieve high-quality
64
+ image/video generation in very few steps. Supports T2I, I2I, and T2V in one sampler.
65
+
66
+ Key innovations:
67
+ - Policy-driven velocity field: predicts the whole path in one forward pass.
68
+ - Multi-stage sampling: coarse generation + optional refinement.
69
+ - High-order ODE solvers (RK4/RK2) for faster convergence.
70
+ - Hybrid flow/diffusion sampling for stable trajectories.
71
+
72
+ Args:
73
+ num_train_timesteps (`int`, defaults to 1000):
74
+ Number of diffusion training steps.
75
+ num_inference_steps (`int`, defaults to 4):
76
+ Inference steps; supports few-step (4–8) generation.
77
+ sigma_min (`float`, defaults to 0.002):
78
+ Minimum noise level.
79
+ sigma_max (`float`, defaults to 80.0):
80
+ Maximum noise level.
81
+ sigma_data (`float`, defaults to 0.5):
82
+ Data std for preconditioning.
83
+ rho (`float`, defaults to 7.0):
84
+ Karras schedule parameter.
85
+ solver_order (`int`, defaults to 2):
86
+ Solver order (1: Euler, 2: RK2/Heun, 4: RK4).
87
+ use_flow_matching (`bool`, defaults to True):
88
+ Whether to use flow-matching mode.
89
+ use_multi_stage (`bool`, defaults to True):
90
+ Whether to use multi-stage sampling.
91
+ coarse_ratio (`float`, defaults to 0.7):
92
+ Fraction of steps allocated to coarse stage.
93
+ snr (`float`, defaults to 0.15):
94
+ SNR factor for correction step size.
95
+ prediction_type (`str`, defaults to "velocity"):
96
+ Prediction type ("velocity", "epsilon", "sample").
97
+ """
98
+
99
+ order = 2 # Default to 2nd-order solver
100
+
101
+ @register_to_config
102
+ def __init__(
103
+ self,
104
+ num_train_timesteps: int = 1000,
105
+ num_inference_steps: int = 4,
106
+ sigma_min: float = 0.002,
107
+ sigma_max: float = 80.0,
108
+ sigma_data: float = 0.5,
109
+ rho: float = 7.0,
110
+ solver_order: int = 2,
111
+ use_flow_matching: bool = True,
112
+ use_multi_stage: bool = True,
113
+ coarse_ratio: float = 0.7,
114
+ snr: float = 0.15,
115
+ prediction_type: str = "velocity",
116
+ ):
117
+ # Initial noise sigma
118
+ self.init_noise_sigma = sigma_max
119
+
120
+ # Mutable state
121
+ self.timesteps = None
122
+ self.sigmas = None
123
+ self.discrete_sigmas = None
124
+ self.num_inference_steps = num_inference_steps
125
+
126
+ # Multi-stage state
127
+ self.current_stage = 0
128
+ self.coarse_steps = int(num_inference_steps * coarse_ratio)
129
+ self.refine_steps = num_inference_steps - self.coarse_steps
130
+
131
+ # Initialize timesteps and sigmas
132
+ self._init_timesteps_and_sigmas()
133
+
134
+ def _init_timesteps_and_sigmas(self):
135
+ """Initialize timesteps and sigma values with Karras schedule."""
136
+ num_steps = self.config.num_inference_steps
137
+ sigma_min = self.config.sigma_min
138
+ sigma_max = self.config.sigma_max
139
+ rho = self.config.rho
140
+
141
+ # Karras sigma schedule: σ_i = (σ_max^(1/ρ) + i/(N-1) * (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ
142
+ ramp = torch.linspace(0, 1, num_steps + 1)
143
+ min_inv_rho = sigma_min ** (1 / rho)
144
+ max_inv_rho = sigma_max ** (1 / rho)
145
+ self.sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
146
+
147
+ # Append final sigma = 0
148
+ self.sigmas = torch.cat([self.sigmas, torch.zeros(1)])
149
+
150
+ # Timesteps from 1 to 0
151
+ self.timesteps = torch.linspace(1, 0, num_steps + 1)
152
+
153
+ # Discrete sigmas for compatibility
154
+ self.discrete_sigmas = self.sigmas[:-1]
155
+
156
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
157
+ """
158
+ Precondition model input with sigma-dependent scaling.
159
+
160
+ Args:
161
+ sample (`torch.Tensor`): Input sample.
162
+ timestep (`int`, *optional*): Current timestep.
163
+
164
+ Returns:
165
+ `torch.Tensor`: Scaled input sample.
166
+ """
167
+ if timestep is None:
168
+ return sample
169
+
170
+ # Fetch current sigma
171
+ step_index = (self.timesteps == timestep).nonzero()
172
+ if len(step_index) == 0:
173
+ return sample
174
+ sigma = self.sigmas[step_index[0]].to(sample.device)
175
+
176
+ # Preconditioning scale: c_in = 1 / sqrt(σ² + σ_data²)
177
+ sigma_data = self.config.sigma_data
178
+ c_in = 1 / (sigma**2 + sigma_data**2).sqrt()
179
+
180
+ return sample * c_in
181
+
182
+ def set_timesteps(
183
+ self,
184
+ num_inference_steps: int,
185
+ device: Union[str, torch.device] = None,
186
+ use_karras_sigmas: bool = True,
187
+ ):
188
+ """
189
+ Set inference timesteps.
190
+
191
+ Args:
192
+ num_inference_steps (`int`): Number of inference steps.
193
+ device (`str` or `torch.device`, *optional*): Target device.
194
+ use_karras_sigmas (`bool`): Whether to use Karras sigma schedule.
195
+ """
196
+ self.num_inference_steps = num_inference_steps
197
+ self.coarse_steps = int(num_inference_steps * self.config.coarse_ratio)
198
+ self.refine_steps = num_inference_steps - self.coarse_steps
199
+
200
+ sigma_min = self.config.sigma_min
201
+ sigma_max = self.config.sigma_max
202
+ rho = self.config.rho
203
+
204
+ if use_karras_sigmas:
205
+ # Karras sigma schedule
206
+ ramp = torch.linspace(0, 1, num_inference_steps + 1)
207
+ min_inv_rho = sigma_min ** (1 / rho)
208
+ max_inv_rho = sigma_max ** (1 / rho)
209
+ self.sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
210
+ else:
211
+ # Linear sigma schedule
212
+ self.sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps + 1)
213
+
214
+ self.sigmas = torch.cat([self.sigmas, torch.zeros(1)])
215
+ self.timesteps = torch.linspace(1, 0, num_inference_steps + 1)
216
+ self.discrete_sigmas = self.sigmas[:-1]
217
+
218
+ if device is not None:
219
+ self.sigmas = self.sigmas.to(device)
220
+ self.timesteps = self.timesteps.to(device)
221
+ self.discrete_sigmas = self.discrete_sigmas.to(device)
222
+
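
Note (standalone sanity check, not part of the file): with the defaults above (sigma_min=0.002, sigma_max=80.0, rho=7.0), set_timesteps(4) yields roughly the sigma ladder shown in the comments below.

    from scheduling_omni import OmniScheduler  # import path depends on where the file is placed

    scheduler = OmniScheduler()  # defaults: 4 steps, Karras sigmas, coarse_ratio=0.7
    scheduler.set_timesteps(4)

    print(scheduler.sigmas)      # ~[80.0, 17.5, 2.5, 0.17, 0.002, 0.0]; the trailing 0 is appended explicitly
    print(scheduler.timesteps)   # tensor([1.00, 0.75, 0.50, 0.25, 0.00])
    print(scheduler.coarse_steps, scheduler.refine_steps)  # 2, 2 -> int(4 * 0.7) coarse steps
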
223
+ def _get_velocity_from_prediction(
224
+ self,
225
+ model_output: torch.Tensor,
226
+ sample: torch.Tensor,
227
+ sigma: torch.Tensor,
228
+ ) -> torch.Tensor:
229
+ """
230
+ Convert model prediction to velocity field based on prediction type.
231
+
232
+ Args:
233
+ model_output: Model output.
234
+ sample: Current sample.
235
+ sigma: Current sigma.
236
+
237
+ Returns:
238
+ velocity: Velocity field (dx/dt).
239
+ """
240
+ sigma_data = self.config.sigma_data
241
+ prediction_type = self.config.prediction_type
242
+
243
+ # Ensure sigma has proper shape
244
+ while len(sigma.shape) < len(sample.shape):
245
+ sigma = sigma.unsqueeze(-1)
246
+
247
+ if prediction_type == "velocity":
248
+ # Direct velocity prediction
249
+ velocity = model_output
250
+ elif prediction_type == "epsilon":
251
+ # From epsilon prediction: v = (x - σ*ε) / σ - x/σ = -ε
252
+ # Flow matching: dx/dt ≈ -ε * σ
253
+ velocity = -model_output * sigma
254
+ elif prediction_type == "sample":
255
+ # From sample prediction: v = (x_pred - x) / σ
256
+ velocity = (model_output - sample) / sigma.clamp(min=1e-8)
257
+ else:
258
+ raise ValueError(f"Unknown prediction_type: {prediction_type}")
259
+
260
+ return velocity
261
+
262
+ def step_euler(
263
+ self,
264
+ model_output: torch.Tensor,
265
+ timestep: int,
266
+ sample: torch.Tensor,
267
+ dt: float,
268
+ generator: Optional[torch.Generator] = None,
269
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
270
+ """
271
+ Single Euler update (1st order).
272
+
273
+ Args:
274
+ model_output: Model output.
275
+ timestep: Current timestep index.
276
+ sample: Current sample.
277
+ dt: Step size.
278
+ generator: Random generator.
279
+
280
+ Returns:
281
+ Updated sample.
282
+ """
283
+ sigma = self.sigmas[timestep].to(sample.device)
284
+ velocity = self._get_velocity_from_prediction(model_output, sample, sigma)
285
+
286
+ # Euler update: x_{t+dt} = x_t + v * dt
287
+ prev_sample = sample + velocity * dt
288
+
289
+ return prev_sample, velocity
290
+
291
+ def step_heun(
292
+ self,
293
+ model_output: torch.Tensor,
294
+ timestep: int,
295
+ sample: torch.Tensor,
296
+ dt: float,
297
+ model_fn=None,
298
+ generator: Optional[torch.Generator] = None,
299
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ """
301
+ Heun's method (improved Euler, 2nd order).
302
+
303
+ Args:
304
+ model_output: Initial model output.
305
+ timestep: Current timestep index.
306
+ sample: Current sample.
307
+ dt: Step size.
308
+ model_fn: Model function for second evaluation.
309
+ generator: Random generator.
310
+
311
+ Returns:
312
+ Updated sample.
313
+ """
314
+ sigma = self.sigmas[timestep].to(sample.device)
315
+ velocity_1 = self._get_velocity_from_prediction(model_output, sample, sigma)
316
+
317
+ # Predictor step (Euler)
318
+ sample_pred = sample + velocity_1 * dt
319
+
320
+ if model_fn is not None and timestep + 1 < len(self.sigmas) - 1:
321
+ # Corrector step
322
+ sigma_next = self.sigmas[timestep + 1].to(sample.device)
323
+ model_output_2 = model_fn(sample_pred, sigma_next)
324
+ velocity_2 = self._get_velocity_from_prediction(model_output_2, sample_pred, sigma_next)
325
+
326
+ # Heun update: x_{t+dt} = x_t + (v_1 + v_2) / 2 * dt
327
+ prev_sample = sample + (velocity_1 + velocity_2) / 2 * dt
328
+ velocity = (velocity_1 + velocity_2) / 2
329
+ else:
330
+ prev_sample = sample_pred
331
+ velocity = velocity_1
332
+
333
+ return prev_sample, velocity
334
+
335
+ def step_rk4(
336
+ self,
337
+ model_output: torch.Tensor,
338
+ timestep: int,
339
+ sample: torch.Tensor,
340
+ dt: float,
341
+ model_fn=None,
342
+ generator: Optional[torch.Generator] = None,
343
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
344
+ """
345
+ Runge-Kutta 4th-order method.
346
+
347
+ Args:
348
+ model_output: Initial model output.
349
+ timestep: Current timestep index.
350
+ sample: Current sample.
351
+ dt: Step size.
352
+ model_fn: Model function for intermediate evaluations.
353
+ generator: Random generator.
354
+
355
+ Returns:
356
+ Updated sample.
357
+ """
358
+ sigma = self.sigmas[timestep].to(sample.device)
359
+
360
+ # k1
361
+ k1 = self._get_velocity_from_prediction(model_output, sample, sigma)
362
+
363
+ if model_fn is None:
364
+ # Fallback to Euler if no model function
365
+ return sample + k1 * dt, k1
366
+
367
+ # Mid sigma values
368
+ sigma_mid = (self.sigmas[timestep] + self.sigmas[min(timestep + 1, len(self.sigmas) - 2)]) / 2
369
+ sigma_mid = sigma_mid.to(sample.device)
370
+ sigma_next = self.sigmas[min(timestep + 1, len(self.sigmas) - 2)].to(sample.device)
371
+
372
+ # k2
373
+ sample_2 = sample + k1 * (dt / 2)
374
+ model_output_2 = model_fn(sample_2, sigma_mid)
375
+ k2 = self._get_velocity_from_prediction(model_output_2, sample_2, sigma_mid)
376
+
377
+ # k3
378
+ sample_3 = sample + k2 * (dt / 2)
379
+ model_output_3 = model_fn(sample_3, sigma_mid)
380
+ k3 = self._get_velocity_from_prediction(model_output_3, sample_3, sigma_mid)
381
+
382
+ # k4
383
+ sample_4 = sample + k3 * dt
384
+ model_output_4 = model_fn(sample_4, sigma_next)
385
+ k4 = self._get_velocity_from_prediction(model_output_4, sample_4, sigma_next)
386
+
387
+ # RK4 update: x_{t+dt} = x_t + (k1 + 2*k2 + 2*k3 + k4) / 6 * dt
388
+ velocity = (k1 + 2 * k2 + 2 * k3 + k4) / 6
389
+ prev_sample = sample + velocity * dt
390
+
391
+ return prev_sample, velocity
392
+
393
+ def step(
394
+ self,
395
+ model_output: torch.Tensor,
396
+ timestep: int,
397
+ sample: torch.Tensor,
398
+ generator: Optional[torch.Generator] = None,
399
+ return_dict: bool = True,
400
+ model_fn=None,
401
+ ) -> Union[OmniSchedulerOutput, Tuple]:
402
+ """
403
+ Execute one sampling step, auto-selecting solver and strategy.
404
+
405
+ Args:
406
+ model_output (`torch.Tensor`): Model output.
407
+ timestep (`int`): Current timestep index.
408
+ sample (`torch.Tensor`): Current sample.
409
+ generator (`torch.Generator`, *optional*): Random generator.
410
+ return_dict (`bool`): Return dict format.
411
+ model_fn: Model function for higher-order solvers.
412
+
413
+ Returns:
414
+ `OmniSchedulerOutput` or `tuple`.
415
+ """
416
+ if self.timesteps is None:
417
+ raise ValueError(
418
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
419
+ )
420
+
421
+ # Compute step size
422
+ if timestep + 1 < len(self.sigmas):
423
+ sigma_curr = self.sigmas[timestep]
424
+ sigma_next = self.sigmas[timestep + 1]
425
+ dt = sigma_next - sigma_curr
426
+ else:
427
+ dt = -self.sigmas[timestep]
428
+
429
+ # Determine current stage
430
+ if self.config.use_multi_stage:
431
+ if timestep < self.coarse_steps:
432
+ self.current_stage = 0 # coarse
433
+ else:
434
+ self.current_stage = 1 # refine
435
+
436
+ # Choose solver order based on stage
437
+ solver_order = self.config.solver_order
438
+
439
+ # Coarse stage may use lower order to speed up
440
+ if self.current_stage == 0 and self.config.use_multi_stage:
441
+ effective_order = min(solver_order, 2)
442
+ else:
443
+ effective_order = solver_order
444
+
445
+ if effective_order == 1:
446
+ prev_sample, velocity = self.step_euler(
447
+ model_output, timestep, sample, dt, generator
448
+ )
449
+ elif effective_order == 2:
450
+ prev_sample, velocity = self.step_heun(
451
+ model_output, timestep, sample, dt, model_fn, generator
452
+ )
453
+ elif effective_order >= 4:
454
+ prev_sample, velocity = self.step_rk4(
455
+ model_output, timestep, sample, dt, model_fn, generator
456
+ )
457
+ else:
458
+ prev_sample, velocity = self.step_euler(
459
+ model_output, timestep, sample, dt, generator
460
+ )
461
+
462
+ prev_sample_mean = prev_sample.clone()
463
+
464
+ if not return_dict:
465
+ return (prev_sample, prev_sample_mean, velocity, self.current_stage)
466
+
467
+ return OmniSchedulerOutput(
468
+ prev_sample=prev_sample,
469
+ prev_sample_mean=prev_sample_mean,
470
+ velocity=velocity,
471
+ stage=self.current_stage,
472
+ )
473
+
474
+ def apply_policy_trajectory(
475
+ self,
476
+ policy_velocities: Optional[torch.Tensor],
477
+ sample: torch.Tensor,
478
+ return_all: bool = False,
479
+ ):
480
+ """
481
+ Run inference directly using the velocity trajectory from a π-Flow policy network.
482
+
483
+ Args:
484
+ policy_velocities: Velocity field sequence of shape [S, *sample.shape].
485
+ sample: Initial noise sample.
486
+ return_all: Whether to return intermediate results for each step.
487
+ """
488
+ if policy_velocities is None:
489
+ return (sample, []) if return_all else sample
490
+
491
+ steps = min(policy_velocities.shape[0], len(self.sigmas) - 1)
492
+ traj = []
493
+ curr = sample
494
+ for i in range(steps):
495
+ out = self.step(policy_velocities[i], i, curr, return_dict=True)
496
+ curr = out.prev_sample
497
+ if return_all:
498
+ traj.append(curr)
499
+
500
+ return (curr, traj) if return_all else curr
501
+
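
Note (sketch, not part of the file): apply_policy_trajectory consumes exactly the [S, *latent_shape] tensor that the pipeline's PiFlowPolicyNetwork produces; a stand-in example:

    import torch
    from scheduling_omni import OmniScheduler  # import path depends on repo layout

    scheduler = OmniScheduler()
    scheduler.set_timesteps(4)
    latents = torch.randn(2, 16, 8, 32, 32) * scheduler.init_noise_sigma
    policy_velocities = torch.randn(4, *latents.shape)  # stand-in for the policy network output

    coarse_latents = scheduler.apply_policy_trajectory(policy_velocities, latents)
    print(coarse_latents.shape)  # torch.Size([2, 16, 8, 32, 32])
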
502
+ def step_correct(
503
+ self,
504
+ model_output: torch.Tensor,
505
+ sample: torch.Tensor,
506
+ generator: Optional[torch.Generator] = None,
507
+ return_dict: bool = True,
508
+ ) -> Union[SchedulerOutput, Tuple]:
509
+ """
510
+ Corrective step to improve sample quality.
511
+
512
+ Args:
513
+ model_output (`torch.Tensor`): Model output.
514
+ sample (`torch.Tensor`): Current sample.
515
+ generator (`torch.Generator`, *optional*): Random generator.
516
+ return_dict (`bool`): Return dict format.
517
+
518
+ Returns:
519
+ `SchedulerOutput` or `tuple`.
520
+ """
521
+ if self.timesteps is None:
522
+ raise ValueError(
523
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
524
+ )
525
+
526
+ # Generate correction noise
527
+ noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
528
+
529
+ # Compute step size
530
+ grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
531
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
532
+ step_size = (self.config.snr * noise_norm / grad_norm.clamp(min=1e-8)) ** 2 * 2
533
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
534
+
535
+ # Adjust shape
536
+ step_size = step_size.flatten()
537
+ while len(step_size.shape) < len(sample.shape):
538
+ step_size = step_size.unsqueeze(-1)
539
+
540
+ # Correction update
541
+ prev_sample_mean = sample + step_size * model_output
542
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
543
+
544
+ if not return_dict:
545
+ return (prev_sample,)
546
+
547
+ return SchedulerOutput(prev_sample=prev_sample)
548
+
549
+ def add_noise(
550
+ self,
551
+ original_samples: torch.Tensor,
552
+ noise: torch.Tensor,
553
+ timesteps: torch.Tensor,
554
+ ) -> torch.Tensor:
555
+ """
556
+ Add noise to original samples.
557
+
558
+ Args:
559
+ original_samples: Original samples.
560
+ noise: Noise tensor.
561
+ timesteps: Timestep indices.
562
+
563
+ Returns:
564
+ Noised samples.
565
+ """
566
+ timesteps = timesteps.to(original_samples.device)
567
+
568
+ # Get corresponding sigma
569
+ sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
570
+
571
+ # Adjust shape
572
+ while len(sigmas.shape) < len(original_samples.shape):
573
+ sigmas = sigmas.unsqueeze(-1)
574
+
575
+ # Add noise
576
+ if noise is not None:
577
+ noisy_samples = original_samples + noise * sigmas
578
+ else:
579
+ noisy_samples = original_samples + torch.randn_like(original_samples) * sigmas
580
+
581
+ return noisy_samples
582
+
583
+ def get_flow_velocity(
584
+ self,
585
+ x_0: torch.Tensor,
586
+ x_1: torch.Tensor,
587
+ t: torch.Tensor,
588
+ ) -> torch.Tensor:
589
+ """
590
+ Compute target velocity field for flow matching.
591
+
592
+ Used to train the policy network.
593
+
594
+ Args:
595
+ x_0: Start sample (noise).
596
+ x_1: Target sample (data).
597
+ t: Time point [0, 1].
598
+
599
+ Returns:
600
+ Target velocity field.
601
+ """
602
+ # Linear interpolation path: x_t = (1-t) * x_0 + t * x_1
603
+ # Velocity field: v = dx/dt = x_1 - x_0
604
+ while len(t.shape) < len(x_0.shape):
605
+ t = t.unsqueeze(-1)
606
+
607
+ velocity = x_1 - x_0
608
+ return velocity
609
+
610
+ def get_interpolated_sample(
611
+ self,
612
+ x_0: torch.Tensor,
613
+ x_1: torch.Tensor,
614
+ t: torch.Tensor,
615
+ ) -> torch.Tensor:
616
+ """
617
+ Get interpolated sample on the flow-matching path.
618
+
619
+ Args:
620
+ x_0: Start sample (noise).
621
+ x_1: Target sample (data).
622
+ t: Time point [0, 1].
623
+
624
+ Returns:
625
+ Interpolated sample x_t.
626
+ """
627
+ while len(t.shape) < len(x_0.shape):
628
+ t = t.unsqueeze(-1)
629
+
630
+ x_t = (1 - t) * x_0 + t * x_1
631
+ return x_t
632
+
633
+ def __len__(self):
634
+ return self.config.num_train_timesteps
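
Note (illustrative only, not training code from this commit): the two flow-matching helpers at the end of the file could be used to build regression targets for the π-Flow policy network, e.g.:

    import torch
    from scheduling_omni import OmniScheduler  # import path depends on repo layout

    scheduler = OmniScheduler()
    x_1 = torch.randn(2, 16, 8, 32, 32)   # stand-in "data" latents
    x_0 = torch.randn_like(x_1)           # pure noise
    t = torch.rand(2)                     # one time point per sample, in [0, 1]

    x_t = scheduler.get_interpolated_sample(x_0, x_1, t)  # (1 - t) * x_0 + t * x_1
    v_target = scheduler.get_flow_velocity(x_0, x_1, t)   # x_1 - x_0, constant along the linear path
    # A velocity/policy model evaluated at (x_t, t) would be regressed onto v_target.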