mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
Support Torch Compile (#1368)
* support simple compile * add support for compile * minor fix * minor fix * minor fix
This commit is contained in:
@@ -339,6 +339,38 @@ class BasePipeline(torch.nn.Module):
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
|
||||
"""
|
||||
compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
|
||||
If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
|
||||
See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
|
||||
Args:
|
||||
mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
|
||||
dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
|
||||
fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
|
||||
compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
|
||||
**kwargs: Other arguments for `torch.compile`.
|
||||
"""
|
||||
compile_models = compile_models or getattr(self, "compilable_models", [])
|
||||
if len(compile_models) == 0:
|
||||
print("No compilable models in the pipeline. Skip compilation.")
|
||||
return
|
||||
for name in compile_models:
|
||||
model = getattr(self, name, None)
|
||||
if model is None:
|
||||
print(f"Model '{name}' not found in the pipeline.")
|
||||
continue
|
||||
repeated_blocks = getattr(model, "_repeated_blocks", None)
|
||||
# regional compilation for repeated blocks.
|
||||
if repeated_blocks is not None:
|
||||
for submod in model.modules():
|
||||
if submod.__class__.__name__ in repeated_blocks:
|
||||
submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
|
||||
# compile the whole model.
|
||||
else:
|
||||
model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
|
||||
print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
|
||||
|
||||
|
||||
class PipelineUnitGraph:
|
||||
def __init__(self):
|
||||
|
||||
@@ -1270,6 +1270,9 @@ class LLMAdapter(nn.Module):
|
||||
|
||||
|
||||
class AnimaDiT(MiniTrainDIT):
|
||||
|
||||
_repeated_blocks = ["Block"]
|
||||
|
||||
def __init__(self):
|
||||
kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -879,6 +879,9 @@ class Flux2Modulation(nn.Module):
|
||||
|
||||
|
||||
class Flux2DiT(torch.nn.Module):
|
||||
|
||||
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
|
||||
@@ -275,6 +275,9 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
|
||||
_repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
|
||||
@@ -1280,6 +1280,7 @@ class LTXModel(torch.nn.Module):
|
||||
LTX model transformer implementation.
|
||||
This class implements the transformer blocks for the LTX model.
|
||||
"""
|
||||
_repeated_blocks = ["BasicAVTransformerBlock"]
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
|
||||
@@ -549,6 +549,9 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class QwenImageDiT(torch.nn.Module):
|
||||
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 60,
|
||||
|
||||
@@ -336,6 +336,9 @@ class WanToDanceInjector(nn.Module):
|
||||
|
||||
|
||||
class WanModel(torch.nn.Module):
|
||||
|
||||
_repeated_blocks = ["DiTBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
|
||||
@@ -326,6 +326,7 @@ class RopeEmbedder:
|
||||
class ZImageDiT(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
_repeated_blocks = ["ZImageTransformerBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -39,6 +39,7 @@ class AnimaImagePipeline(BasePipeline):
|
||||
AnimaUnit_PromptEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_anima
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -42,6 +42,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
Flux2Unit_ImageIDs(),
|
||||
]
|
||||
self.model_fn = model_fn_flux2
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -103,6 +103,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_LoRAEncode(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
self.compilable_models = ["dit"]
|
||||
self.lora_loader = FluxLoRALoader
|
||||
|
||||
def enable_lora_merger(self):
|
||||
|
||||
@@ -76,6 +76,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
LTX2AudioVideoUnit_SetScheduleStage2(),
|
||||
]
|
||||
self.model_fn = model_fn_ltx2
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
self.default_negative_prompt = {
|
||||
"LTX-2": (
|
||||
|
||||
@@ -52,6 +52,7 @@ class MovaAudioVideoPipeline(BasePipeline):
|
||||
MovaAudioVideoUnit_UnifiedSequenceParallel(),
|
||||
]
|
||||
self.model_fn = model_fn_mova_audio_video
|
||||
self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]
|
||||
|
||||
def enable_usp(self):
|
||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
|
||||
|
||||
@@ -56,6 +56,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_qwen_image
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -83,6 +83,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoPostUnit_S2V(),
|
||||
]
|
||||
self.model_fn = model_fn_wan_video
|
||||
self.compilable_models = ["dit", "dit2"]
|
||||
|
||||
|
||||
def enable_usp(self):
|
||||
|
||||
@@ -54,6 +54,7 @@ class ZImagePipeline(BasePipeline):
|
||||
ZImageUnit_PAIControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user