diff --git a/README.md b/README.md index 46e3634..39af15a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. -- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. +- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. - **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. 
For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/). diff --git a/README_zh.md b/README_zh.md index d98ceaa..98a6136 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,7 +33,7 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 -- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。 +- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。 - **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。 diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index face319..588f765 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -339,6 +339,38 @@ class BasePipeline(torch.nn.Module): noise_pred = noise_pred_posi return noise_pred + def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs): + """ + compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline. + If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model. 
+ See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments. + Args: + mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs". Default to "default". + dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended). + fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended). + compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None. + **kwargs: Other arguments for `torch.compile`. + """ + compile_models = compile_models or getattr(self, "compilable_models", []) + if len(compile_models) == 0: + print("No compilable models in the pipeline. Skip compilation.") + return + for name in compile_models: + model = getattr(self, name, None) + if model is None: + print(f"Model '{name}' not found in the pipeline.") + continue + repeated_blocks = getattr(model, "_repeated_blocks", None) + # regional compilation for repeated blocks. + if repeated_blocks is not None: + for submod in model.modules(): + if submod.__class__.__name__ in repeated_blocks: + submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs) + # compile the whole model. 
+ else: + model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs) + print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.") + class PipelineUnitGraph: def __init__(self): diff --git a/diffsynth/models/anima_dit.py b/diffsynth/models/anima_dit.py index dbd1407..d751980 100644 --- a/diffsynth/models/anima_dit.py +++ b/diffsynth/models/anima_dit.py @@ -1270,6 +1270,9 @@ class LLMAdapter(nn.Module): class AnimaDiT(MiniTrainDIT): + + _repeated_blocks = ["Block"] + def __init__(self): kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn} super().__init__(**kwargs) diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py index a1bd02a..1eecadd 100644 --- a/diffsynth/models/flux2_dit.py +++ b/diffsynth/models/flux2_dit.py @@ -879,6 +879,9 @@ class Flux2Modulation(nn.Module): class Flux2DiT(torch.nn.Module): + + _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + def __init__( self, patch_size: int = 1, diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 51a6e7f..46fa861 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -275,6 +275,9 @@ class AdaLayerNormContinuous(torch.nn.Module): class 
FluxDiT(torch.nn.Module): + + _repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"] + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): super().__init__() self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index 8ce5249..9df0ed3 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -1280,6 +1280,7 @@ class LTXModel(torch.nn.Module): LTX model transformer implementation. This class implements the transformer blocks for the LTX model. """ + _repeated_blocks = ["BasicAVTransformerBlock"] def __init__( # noqa: PLR0913 self, diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index 2dd5143..aeb6dd2 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -549,6 +549,9 @@ class QwenImageTransformerBlock(nn.Module): class QwenImageDiT(torch.nn.Module): + + _repeated_blocks = ["QwenImageTransformerBlock"] + def __init__( self, num_layers: int = 60, diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 7e5cec6..52f607e 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -336,6 +336,9 @@ class WanToDanceInjector(nn.Module): class WanModel(torch.nn.Module): + + _repeated_blocks = ["DiTBlock"] + def __init__( self, dim: int, diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index 810def2..6a0dc33 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -326,6 +326,7 @@ class RopeEmbedder: class ZImageDiT(nn.Module): _supports_gradient_checkpointing = True _no_split_modules = ["ZImageTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock"] def __init__( self, diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py index 32a3c71..bc4f6cd 100644 --- 
a/diffsynth/pipelines/anima_image.py +++ b/diffsynth/pipelines/anima_image.py @@ -39,6 +39,7 @@ class AnimaImagePipeline(BasePipeline): AnimaUnit_PromptEmbedder(), ] self.model_fn = model_fn_anima + self.compilable_models = ["dit"] @staticmethod diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 34f4d27..7b6dcc4 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -42,6 +42,7 @@ class Flux2ImagePipeline(BasePipeline): Flux2Unit_ImageIDs(), ] self.model_fn = model_fn_flux2 + self.compilable_models = ["dit"] @staticmethod diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index bfc53e5..db2d522 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -103,6 +103,7 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_LoRAEncode(), ] self.model_fn = model_fn_flux_image + self.compilable_models = ["dit"] self.lora_loader = FluxLoRALoader def enable_lora_merger(self): diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 5ef1738..1263b43 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -76,6 +76,7 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_SetScheduleStage2(), ] self.model_fn = model_fn_ltx2 + self.compilable_models = ["dit"] self.default_negative_prompt = { "LTX-2": ( diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py index b74a648..d89d3ff 100644 --- a/diffsynth/pipelines/mova_audio_video.py +++ b/diffsynth/pipelines/mova_audio_video.py @@ -52,6 +52,7 @@ class MovaAudioVideoPipeline(BasePipeline): MovaAudioVideoUnit_UnifiedSequenceParallel(), ] self.model_fn = model_fn_mova_audio_video + self.compilable_models = ["video_dit", "video_dit2", "audio_dit"] def enable_usp(self): from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward diff --git 
a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 9677f86..f3256a1 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -56,6 +56,7 @@ class QwenImagePipeline(BasePipeline): QwenImageUnit_BlockwiseControlNet(), ] self.model_fn = model_fn_qwen_image + self.compilable_models = ["dit"] @staticmethod diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index a2da1c4..1c1aa7e 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -83,6 +83,7 @@ class WanVideoPipeline(BasePipeline): WanVideoPostUnit_S2V(), ] self.model_fn = model_fn_wan_video + self.compilable_models = ["dit", "dit2"] def enable_usp(self): diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index d01f912..59e44b3 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -54,6 +54,7 @@ class ZImagePipeline(BasePipeline): ZImageUnit_PAIControlNet(), ] self.model_fn = model_fn_z_image + self.compilable_models = ["dit"] @staticmethod