mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-13 04:18:19 +00:00
Compare commits
8 Commits
compatibil
...
webui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
224060c2a0 | ||
|
|
166e6d2d38 | ||
|
|
5e7e3db0af | ||
|
|
ae8cb139e8 | ||
|
|
e2a3a987da | ||
|
|
f7b9ae7d57 | ||
|
|
5d198287f0 | ||
|
|
5bccd60c80 |
@@ -32,7 +32,7 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|||||||
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
|
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
|
||||||
|
|
||||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||||
- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
||||||
|
|
||||||
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ DiffSynth 目前包括两个开源项目:
|
|||||||
|
|
||||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
|
||||||
- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
||||||
|
|
||||||
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
|
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
|
||||||
|
|
||||||
|
|||||||
@@ -604,6 +604,13 @@ z_image_series = [
|
|||||||
"extra_kwargs": {"model_size": "0.6B"},
|
"extra_kwargs": {"model_size": "0.6B"},
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
# To ensure compatibility with the `model.diffusion_model` prefix introduced by other frameworks.
|
||||||
|
"model_hash": "8cf241a0d32f93d5de368502a086852f",
|
||||||
|
"model_name": "z_image_dit",
|
||||||
|
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_dit.ZImageDiTStateDictConverter",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
||||||
|
|||||||
@@ -339,6 +339,38 @@ class BasePipeline(torch.nn.Module):
|
|||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
|
def compile_pipeline(self, mode: str = "default", dynamic: bool = True, fullgraph: bool = False, compile_models: list = None, **kwargs):
|
||||||
|
"""
|
||||||
|
compile the pipeline with torch.compile. The models that will be compiled are determined by the `compilable_models` attribute of the pipeline.
|
||||||
|
If a model has `_repeated_blocks` attribute, we will compile these blocks with regional compilation. Otherwise, we will compile the whole model.
|
||||||
|
See https://docs.pytorch.org/docs/stable/generated/torch.compile.html#torch.compile for details about compilation arguments.
|
||||||
|
Args:
|
||||||
|
mode: The compilation mode, which will be passed to `torch.compile`, options are "default", "reduce-overhead", "max-autotune" and "max-autotune-no-cudagraphs. Default to "default".
|
||||||
|
dynamic: Whether to enable dynamic graph compilation to support dynamic input shapes, which will be passed to `torch.compile`. Default to True (recommended).
|
||||||
|
fullgraph: Whether to use full graph compilation, which will be passed to `torch.compile`. Default to False (recommended).
|
||||||
|
compile_models: The list of model names to be compiled. If None, we will compile the models in `pipeline.compilable_models`. Default to None.
|
||||||
|
**kwargs: Other arguments for `torch.compile`.
|
||||||
|
"""
|
||||||
|
compile_models = compile_models or getattr(self, "compilable_models", [])
|
||||||
|
if len(compile_models) == 0:
|
||||||
|
print("No compilable models in the pipeline. Skip compilation.")
|
||||||
|
return
|
||||||
|
for name in compile_models:
|
||||||
|
model = getattr(self, name, None)
|
||||||
|
if model is None:
|
||||||
|
print(f"Model '{name}' not found in the pipeline.")
|
||||||
|
continue
|
||||||
|
repeated_blocks = getattr(model, "_repeated_blocks", None)
|
||||||
|
# regional compilation for repeated blocks.
|
||||||
|
if repeated_blocks is not None:
|
||||||
|
for submod in model.modules():
|
||||||
|
if submod.__class__.__name__ in repeated_blocks:
|
||||||
|
submod.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
|
||||||
|
# compile the whole model.
|
||||||
|
else:
|
||||||
|
model.compile(mode=mode, dynamic=dynamic, fullgraph=fullgraph, **kwargs)
|
||||||
|
print(f"{name} is compiled with mode={mode}, dynamic={dynamic}, fullgraph={fullgraph}.")
|
||||||
|
|
||||||
|
|
||||||
class PipelineUnitGraph:
|
class PipelineUnitGraph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -1270,6 +1270,9 @@ class LLMAdapter(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class AnimaDiT(MiniTrainDIT):
|
class AnimaDiT(MiniTrainDIT):
|
||||||
|
|
||||||
|
_repeated_blocks = ["Block"]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
|
kwargs = {'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn}
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor
|
||||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
value_bias = False
|
value_bias = False
|
||||||
)
|
)
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.processor = DINOv3ViTImageProcessorFast(
|
self.processor = DINOv3ViTImageProcessor(
|
||||||
crop_size = None,
|
crop_size = None,
|
||||||
data_format = "channels_first",
|
data_format = "channels_first",
|
||||||
default_to_square = True,
|
default_to_square = True,
|
||||||
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
|||||||
0.456,
|
0.456,
|
||||||
0.406
|
0.406
|
||||||
],
|
],
|
||||||
image_processor_type = "DINOv3ViTImageProcessorFast",
|
image_processor_type = "DINOv3ViTImageProcessor",
|
||||||
image_std = [
|
image_std = [
|
||||||
0.229,
|
0.229,
|
||||||
0.224,
|
0.224,
|
||||||
|
|||||||
@@ -879,6 +879,9 @@ class Flux2Modulation(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Flux2DiT(torch.nn.Module):
|
class Flux2DiT(torch.nn.Module):
|
||||||
|
|
||||||
|
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patch_size: int = 1,
|
patch_size: int = 1,
|
||||||
|
|||||||
@@ -275,6 +275,9 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FluxDiT(torch.nn.Module):
|
class FluxDiT(torch.nn.Module):
|
||||||
|
|
||||||
|
_repeated_blocks = ["FluxJointTransformerBlock", "FluxSingleTransformerBlock"]
|
||||||
|
|
||||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||||
|
|||||||
@@ -1280,6 +1280,7 @@ class LTXModel(torch.nn.Module):
|
|||||||
LTX model transformer implementation.
|
LTX model transformer implementation.
|
||||||
This class implements the transformer blocks for the LTX model.
|
This class implements the transformer blocks for the LTX model.
|
||||||
"""
|
"""
|
||||||
|
_repeated_blocks = ["BasicAVTransformerBlock"]
|
||||||
|
|
||||||
def __init__( # noqa: PLR0913
|
def __init__( # noqa: PLR0913
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -549,6 +549,9 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QwenImageDiT(torch.nn.Module):
|
class QwenImageDiT(torch.nn.Module):
|
||||||
|
|
||||||
|
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_layers: int = 60,
|
num_layers: int = 60,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
|
||||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffsynth.core.device.npu_compatible_device import get_device_type
|
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||||
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
|||||||
transformers_version = "4.57.1"
|
transformers_version = "4.57.1"
|
||||||
)
|
)
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.processor = Siglip2ImageProcessorFast(
|
self.processor = Siglip2ImageProcessor(
|
||||||
**{
|
**{
|
||||||
"data_format": "channels_first",
|
"data_format": "channels_first",
|
||||||
"default_to_square": True,
|
"default_to_square": True,
|
||||||
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
|||||||
0.5,
|
0.5,
|
||||||
0.5
|
0.5
|
||||||
],
|
],
|
||||||
"image_processor_type": "Siglip2ImageProcessorFast",
|
"image_processor_type": "Siglip2ImageProcessor",
|
||||||
"image_std": [
|
"image_std": [
|
||||||
0.5,
|
0.5,
|
||||||
0.5,
|
0.5,
|
||||||
|
|||||||
@@ -336,6 +336,9 @@ class WanToDanceInjector(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class WanModel(torch.nn.Module):
|
class WanModel(torch.nn.Module):
|
||||||
|
|
||||||
|
_repeated_blocks = ["DiTBlock"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
|
|||||||
@@ -326,6 +326,7 @@ class RopeEmbedder:
|
|||||||
class ZImageDiT(nn.Module):
|
class ZImageDiT(nn.Module):
|
||||||
_supports_gradient_checkpointing = True
|
_supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["ZImageTransformerBlock"]
|
_no_split_modules = ["ZImageTransformerBlock"]
|
||||||
|
_repeated_blocks = ["ZImageTransformerBlock"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class AnimaImagePipeline(BasePipeline):
|
|||||||
AnimaUnit_PromptEmbedder(),
|
AnimaUnit_PromptEmbedder(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_anima
|
self.model_fn = model_fn_anima
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
Flux2Unit_ImageIDs(),
|
Flux2Unit_ImageIDs(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_flux2
|
self.model_fn = model_fn_flux2
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
FluxImageUnit_LoRAEncode(),
|
FluxImageUnit_LoRAEncode(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_flux_image
|
self.model_fn = model_fn_flux_image
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
self.lora_loader = FluxLoRALoader
|
self.lora_loader = FluxLoRALoader
|
||||||
|
|
||||||
def enable_lora_merger(self):
|
def enable_lora_merger(self):
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
LTX2AudioVideoUnit_SetScheduleStage2(),
|
LTX2AudioVideoUnit_SetScheduleStage2(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_ltx2
|
self.model_fn = model_fn_ltx2
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
self.default_negative_prompt = {
|
self.default_negative_prompt = {
|
||||||
"LTX-2": (
|
"LTX-2": (
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class MovaAudioVideoPipeline(BasePipeline):
|
|||||||
MovaAudioVideoUnit_UnifiedSequenceParallel(),
|
MovaAudioVideoUnit_UnifiedSequenceParallel(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_mova_audio_video
|
self.model_fn = model_fn_mova_audio_video
|
||||||
|
self.compilable_models = ["video_dit", "video_dit2", "audio_dit"]
|
||||||
|
|
||||||
def enable_usp(self):
|
def enable_usp(self):
|
||||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
|
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
QwenImageUnit_BlockwiseControlNet(),
|
QwenImageUnit_BlockwiseControlNet(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_qwen_image
|
self.model_fn = model_fn_qwen_image
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -83,10 +83,11 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
WanVideoPostUnit_S2V(),
|
WanVideoPostUnit_S2V(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_wan_video
|
self.model_fn = model_fn_wan_video
|
||||||
|
self.compilable_models = ["dit", "dit2"]
|
||||||
|
|
||||||
|
|
||||||
def enable_usp(self):
|
def enable_usp(self):
|
||||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
|
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward
|
||||||
|
|
||||||
for block in self.dit.blocks:
|
for block in self.dit.blocks:
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
@@ -95,6 +96,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
for block in self.dit2.blocks:
|
for block in self.dit2.blocks:
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||||
|
if self.vace is not None:
|
||||||
|
for block in self.vace.vace_blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
|
||||||
|
if self.vace2 is not None:
|
||||||
|
for block in self.vace2.vace_blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
|
||||||
self.sp_size = get_sequence_parallel_world_size()
|
self.sp_size = get_sequence_parallel_world_size()
|
||||||
self.use_unified_sequence_parallel = True
|
self.use_unified_sequence_parallel = True
|
||||||
|
|
||||||
@@ -1450,13 +1459,6 @@ def model_fn_wan_video(
|
|||||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||||
else:
|
else:
|
||||||
tea_cache_update = False
|
tea_cache_update = False
|
||||||
|
|
||||||
if vace_context is not None:
|
|
||||||
vace_hints = vace(
|
|
||||||
x, vace_context, context, t_mod, freqs,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
|
||||||
)
|
|
||||||
|
|
||||||
# WanToDance
|
# WanToDance
|
||||||
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
||||||
@@ -1519,6 +1521,13 @@ def model_fn_wan_video(
|
|||||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||||
x = chunks[get_sequence_parallel_rank()]
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
if vace_context is not None:
|
||||||
|
vace_hints = vace(
|
||||||
|
x, vace_context, context, t_mod, freqs,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||||
|
)
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
@@ -1561,9 +1570,6 @@ def model_fn_wan_video(
|
|||||||
# VACE
|
# VACE
|
||||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
|
||||||
x = x + current_vace_hint * vace_scale
|
x = x + current_vace_hint * vace_scale
|
||||||
|
|
||||||
# Animate
|
# Animate
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
ZImageUnit_PAIControlNet(),
|
ZImageUnit_PAIControlNet(),
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_z_image
|
self.model_fn = model_fn_z_image
|
||||||
|
self.compilable_models = ["dit"]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -94,7 +95,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
# Image
|
# Image
|
||||||
@@ -108,7 +109,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
width: int = 1024,
|
width: int = 1024,
|
||||||
# Randomness
|
# Randomness
|
||||||
seed: int = None,
|
seed: int = None,
|
||||||
rand_device: str = "cpu",
|
rand_device: Union[str, torch.device] = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 8,
|
num_inference_steps: int = 8,
|
||||||
sigma_shift: float = None,
|
sigma_shift: float = None,
|
||||||
|
|||||||
3
diffsynth/utils/state_dict_converters/z_image_dit.py
Normal file
3
diffsynth/utils/state_dict_converters/z_image_dit.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
def ZImageDiTStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {name.replace("model.diffusion_model.", ""): state_dict[name] for name in state_dict}
|
||||||
|
return state_dict_
|
||||||
@@ -1 +1 @@
|
|||||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||||
|
|||||||
@@ -117,6 +117,39 @@ def usp_dit_forward(self,
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def usp_vace_forward(
|
||||||
|
self, x, vace_context, context, t_mod, freqs,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
):
|
||||||
|
# Compute full sequence length from the sharded x
|
||||||
|
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
|
||||||
|
|
||||||
|
# Embed vace_context via patch embedding
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
# Chunk VACE context along sequence dim BEFORE processing through blocks
|
||||||
|
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
|
||||||
|
for block in self.vace_blocks:
|
||||||
|
c = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
c, x, context, t_mod, freqs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hints are already sharded per-rank
|
||||||
|
hints = torch.unbind(c)[:-1]
|
||||||
|
return hints
|
||||||
|
|
||||||
|
|
||||||
def usp_attn_forward(self, x, freqs):
|
def usp_attn_forward(self, x, freqs):
|
||||||
q = self.norm_q(self.q(x))
|
q = self.norm_q(self.q(x))
|
||||||
k = self.norm_k(self.k(x))
|
k = self.norm_k(self.k(x))
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ For more information about installation, please refer to [Installation Dependenc
|
|||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
Run the following code to quickly load the [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
@@ -24,88 +24,36 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
|
|||||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": torch.float8_e5m2,
|
"offload_dtype": torch.bfloat16,
|
||||||
"offload_device": "cpu",
|
"offload_device": "cpu",
|
||||||
"onload_dtype": torch.float8_e5m2,
|
"onload_dtype": torch.bfloat16,
|
||||||
"onload_device": "cpu",
|
"onload_device": "cuda",
|
||||||
"preparing_dtype": torch.float8_e5m2,
|
"preparing_dtype": torch.bfloat16,
|
||||||
"preparing_device": "cuda",
|
"preparing_device": "cuda",
|
||||||
"computation_dtype": torch.bfloat16,
|
"computation_dtype": torch.bfloat16,
|
||||||
"computation_device": "cuda",
|
"computation_device": "cuda",
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
|
||||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
|
||||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
|
||||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
|
||||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
|
||||||
and avoid redundant memory usage when users only want to use part of the model.
|
|
||||||
"""
|
|
||||||
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
|
|
||||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
)
|
||||||
|
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
|
||||||
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
|
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
|
||||||
# pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|
||||||
# torch_dtype=torch.bfloat16,
|
|
||||||
# device="cuda",
|
|
||||||
# model_configs=[
|
|
||||||
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
|
||||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
|
||||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
|
||||||
# ],
|
|
||||||
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
|
||||||
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
|
||||||
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
# )
|
|
||||||
|
|
||||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
|
||||||
negative_prompt = (
|
|
||||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
|
||||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
|
||||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
|
||||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
|
||||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
|
||||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
|
||||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
|
||||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
|
||||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
|
||||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
|
||||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
|
||||||
)
|
|
||||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
|
||||||
video, audio = pipe(
|
video, audio = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=43,
|
seed=43,
|
||||||
height=height,
|
height=1024, width=1536, num_frames=121,
|
||||||
width=width,
|
tiled=True, use_two_stage_pipeline=True,
|
||||||
num_frames=num_frames,
|
|
||||||
tiled=True,
|
|
||||||
use_two_stage_pipeline=True,
|
|
||||||
)
|
|
||||||
write_video_audio_ltx2(
|
|
||||||
video=video,
|
|
||||||
audio=audio,
|
|
||||||
output_path='ltx2_twostage.mp4',
|
|
||||||
fps=24,
|
|
||||||
audio_sample_rate=24000,
|
|
||||||
)
|
)
|
||||||
|
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Model Overview
|
## Model Overview
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ pip install -e .
|
|||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
运行以下代码可以快速加载 [Lightricks/LTX-2.3](https://www.modelscope.cn/models/Lightricks/LTX-2.3) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
@@ -24,88 +24,36 @@ from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelCo
|
|||||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": torch.float8_e5m2,
|
"offload_dtype": torch.bfloat16,
|
||||||
"offload_device": "cpu",
|
"offload_device": "cpu",
|
||||||
"onload_dtype": torch.float8_e5m2,
|
"onload_dtype": torch.bfloat16,
|
||||||
"onload_device": "cpu",
|
"onload_device": "cuda",
|
||||||
"preparing_dtype": torch.float8_e5m2,
|
"preparing_dtype": torch.bfloat16,
|
||||||
"preparing_device": "cuda",
|
"preparing_device": "cuda",
|
||||||
"computation_dtype": torch.bfloat16,
|
"computation_dtype": torch.bfloat16,
|
||||||
"computation_device": "cuda",
|
"computation_device": "cuda",
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
Offical model repo: https://www.modelscope.cn/models/Lightricks/LTX-2
|
|
||||||
Repackaged model repo: https://www.modelscope.cn/models/DiffSynth-Studio/LTX-2-Repackage
|
|
||||||
For base models of LTX-2, offical checkpoint (with model config ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors"))
|
|
||||||
and repackaged checkpoints (with model config ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="*.safetensors")) are both supported.
|
|
||||||
We have repackeged the official checkpoints in DiffSynth-Studio/LTX-2-Repackage repo to support separate loading of different submodules,
|
|
||||||
and avoid redundant memory usage when users only want to use part of the model.
|
|
||||||
"""
|
|
||||||
# use the repackaged modelconfig from "DiffSynth-Studio/LTX-2-Repackage" to avoid redundant model loading
|
|
||||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-dev.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"),
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
)
|
)
|
||||||
|
prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'"
|
||||||
# use the following modelconfig if you want to initialize model from offical checkpoints from "Lightricks/LTX-2"
|
negative_prompt = pipe.default_negative_prompt["LTX-2.3"]
|
||||||
# pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|
||||||
# torch_dtype=torch.bfloat16,
|
|
||||||
# device="cuda",
|
|
||||||
# model_configs=[
|
|
||||||
# ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
|
||||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
|
||||||
# ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
|
||||||
# ],
|
|
||||||
# tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
|
||||||
# stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
|
||||||
# vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
|
||||||
# )
|
|
||||||
|
|
||||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
|
||||||
negative_prompt = (
|
|
||||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
|
||||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
|
||||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
|
||||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
|
||||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
|
||||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
|
||||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
|
||||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
|
||||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
|
||||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
|
||||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
|
||||||
)
|
|
||||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
|
||||||
video, audio = pipe(
|
video, audio = pipe(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=43,
|
seed=43,
|
||||||
height=height,
|
height=1024, width=1536, num_frames=121,
|
||||||
width=width,
|
tiled=True, use_two_stage_pipeline=True,
|
||||||
num_frames=num_frames,
|
|
||||||
tiled=True,
|
|
||||||
use_two_stage_pipeline=True,
|
|
||||||
)
|
|
||||||
write_video_audio_ltx2(
|
|
||||||
video=video,
|
|
||||||
audio=audio,
|
|
||||||
output_path='ltx2_twostage.mp4',
|
|
||||||
fps=24,
|
|
||||||
audio_sample_rate=24000,
|
|
||||||
)
|
)
|
||||||
|
write_video_audio_ltx2(video=video, audio=audio, output_path='video.mp4', fps=24, audio_sample_rate=pipe.audio_vocoder.output_sampling_rate)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 模型总览
|
## 模型总览
|
||||||
|
|||||||
283
examples/dev_tools/webui.py
Normal file
283
examples/dev_tools/webui.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
import importlib, inspect, pkgutil, traceback, torch, os, re
|
||||||
|
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import streamlit as st
|
||||||
|
from diffsynth import ModelConfig
|
||||||
|
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
st.set_page_config(layout="wide")
|
||||||
|
|
||||||
|
class StreamlitTqdmWrapper:
|
||||||
|
"""Wrapper class that combines tqdm and streamlit progress bar"""
|
||||||
|
def __init__(self, iterable, st_progress_bar=None):
|
||||||
|
self.iterable = iterable
|
||||||
|
self.st_progress_bar = st_progress_bar
|
||||||
|
self.tqdm_bar = tqdm(iterable)
|
||||||
|
self.total = len(iterable) if hasattr(iterable, '__len__') else None
|
||||||
|
self.current = 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for item in self.tqdm_bar:
|
||||||
|
if self.st_progress_bar is not None and self.total is not None:
|
||||||
|
self.current += 1
|
||||||
|
self.st_progress_bar.progress(self.current / self.total)
|
||||||
|
yield item
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
if hasattr(self.tqdm_bar, '__exit__'):
|
||||||
|
self.tqdm_bar.__exit__(*args)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def catch_error(error_value):
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except Exception as e:
|
||||||
|
error_message = traceback.format_exc()
|
||||||
|
print(f"Error {error_value}:\n{error_message}")
|
||||||
|
|
||||||
|
def parse_model_configs_from_an_example(path):
|
||||||
|
model_configs = []
|
||||||
|
with open(path, "r") as f:
|
||||||
|
for code in f.readlines():
|
||||||
|
code = code.strip()
|
||||||
|
if not code.startswith("ModelConfig"):
|
||||||
|
continue
|
||||||
|
pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
|
||||||
|
config_dict = {k: v for k, v in pairs}
|
||||||
|
model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"]))
|
||||||
|
return model_configs
|
||||||
|
|
||||||
|
def list_examples(path, keyword=None):
|
||||||
|
examples = []
|
||||||
|
if os.path.isdir(path):
|
||||||
|
for file_name in os.listdir(path):
|
||||||
|
examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword))
|
||||||
|
elif path.endswith(".py"):
|
||||||
|
with open(path, "r") as f:
|
||||||
|
code = f.read()
|
||||||
|
if keyword is None or keyword in code:
|
||||||
|
examples.extend([path])
|
||||||
|
return examples
|
||||||
|
|
||||||
|
def parse_available_pipelines():
|
||||||
|
from diffsynth.diffusion.base_pipeline import BasePipeline
|
||||||
|
import diffsynth.pipelines as _pipelines_pkg
|
||||||
|
available_pipelines = {}
|
||||||
|
for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
|
||||||
|
with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
|
||||||
|
mod = importlib.import_module(f"diffsynth.pipelines.{name}")
|
||||||
|
classes = {
|
||||||
|
cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass)
|
||||||
|
if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__
|
||||||
|
}
|
||||||
|
available_pipelines.update(classes)
|
||||||
|
return available_pipelines
|
||||||
|
|
||||||
|
def parse_available_examples(path, available_pipelines):
|
||||||
|
available_examples = {}
|
||||||
|
for pipeline_name in available_pipelines:
|
||||||
|
examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
|
||||||
|
available_examples[pipeline_name] = examples
|
||||||
|
return available_examples
|
||||||
|
|
||||||
|
def draw_selectbox(label, options, option_map, value=None, disabled=False):
|
||||||
|
default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0])
|
||||||
|
option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled)
|
||||||
|
return option_map.get(option)
|
||||||
|
|
||||||
|
def parse_params(fn):
|
||||||
|
params = []
|
||||||
|
for name, param in inspect.signature(fn).parameters.items():
|
||||||
|
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None
|
||||||
|
default = param.default if param.default is not inspect.Parameter.empty else None
|
||||||
|
params.append({"name": name, "dtype": annotation, "value": default})
|
||||||
|
return params
|
||||||
|
|
||||||
|
def draw_model_config(model_config=None, key_suffix="", disabled=False):
|
||||||
|
with st.container(border=True):
|
||||||
|
if model_config is None:
|
||||||
|
model_config = ModelConfig()
|
||||||
|
path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
|
||||||
|
with col2:
|
||||||
|
origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
|
||||||
|
model_config = ModelConfig(
|
||||||
|
path=None if path == "" else path,
|
||||||
|
model_id=model_id,
|
||||||
|
origin_file_pattern=origin_file_pattern,
|
||||||
|
)
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
def draw_multi_model_config(name="", value=None, disabled=False):
|
||||||
|
model_configs = []
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
||||||
|
for i in range(num):
|
||||||
|
model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled)
|
||||||
|
model_configs.append(model_config)
|
||||||
|
return model_configs
|
||||||
|
|
||||||
|
def draw_single_model_config(name="", value=None, disabled=False):
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
def draw_multi_images(name="", value=None, disabled=False):
|
||||||
|
images = []
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
||||||
|
for i in range(num):
|
||||||
|
image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
|
||||||
|
if image is not None: images.append(Image.open(image))
|
||||||
|
return images
|
||||||
|
|
||||||
|
def draw_controlnet_input(name="", value=None, disabled=False):
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
|
||||||
|
scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
|
||||||
|
image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image")
|
||||||
|
if image is not None: image = Image.open(image)
|
||||||
|
inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image")
|
||||||
|
if inpaint_image is not None: inpaint_image = Image.open(inpaint_image)
|
||||||
|
inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask")
|
||||||
|
if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask)
|
||||||
|
return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
|
||||||
|
|
||||||
|
def draw_controlnet_inputs(name, value=None, disabled=False):
|
||||||
|
controlnet_inputs = []
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
||||||
|
for i in range(num):
|
||||||
|
controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
|
||||||
|
controlnet_inputs.append(controlnet_input)
|
||||||
|
return controlnet_inputs
|
||||||
|
|
||||||
|
def draw_ui_element(name, dtype, value):
|
||||||
|
unsupported_dtype = [
|
||||||
|
Dict[str, torch.Tensor],
|
||||||
|
torch.Tensor,
|
||||||
|
]
|
||||||
|
if dtype in unsupported_dtype:
|
||||||
|
return
|
||||||
|
if value is None:
|
||||||
|
with st.container(border=True):
|
||||||
|
enable = st.checkbox(f"Enable {name}", value=False)
|
||||||
|
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
|
||||||
|
if enable:
|
||||||
|
return ui
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return draw_ui_element_safely(name, dtype, value)
|
||||||
|
|
||||||
|
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||||
|
if dtype == torch.dtype:
|
||||||
|
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||||
|
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
||||||
|
elif dtype == Union[str, torch.device]:
|
||||||
|
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
||||||
|
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
||||||
|
elif dtype == bool:
|
||||||
|
ui = st.checkbox(name, value, disabled=disabled)
|
||||||
|
elif dtype == ModelConfig:
|
||||||
|
ui = draw_single_model_config(name, value, disabled=disabled)
|
||||||
|
elif dtype == list[ModelConfig]:
|
||||||
|
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
||||||
|
model_configs = st.session_state["model_configs_from_example"]
|
||||||
|
del st.session_state["model_configs_from_example"]
|
||||||
|
ui = draw_multi_model_config(name, model_configs, disabled=disabled)
|
||||||
|
else:
|
||||||
|
ui = draw_multi_model_config(name, disabled=disabled)
|
||||||
|
elif dtype == str:
|
||||||
|
if "prompt" in name:
|
||||||
|
ui = st.text_area(name, value, height=3, disabled=disabled)
|
||||||
|
else:
|
||||||
|
ui = st.text_input(name, value, disabled=disabled)
|
||||||
|
elif dtype == float:
|
||||||
|
ui = st.number_input(name, value, disabled=disabled)
|
||||||
|
elif dtype == int:
|
||||||
|
ui = st.number_input(name, value, step=1, disabled=disabled)
|
||||||
|
elif dtype == Image.Image:
|
||||||
|
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
||||||
|
if ui is not None: ui = Image.open(ui)
|
||||||
|
elif dtype == List[Image.Image]:
|
||||||
|
ui = draw_multi_images(name, value, disabled=disabled)
|
||||||
|
elif dtype == List[ControlNetInput]:
|
||||||
|
ui = draw_controlnet_inputs(name, value, disabled=disabled)
|
||||||
|
elif dtype is None:
|
||||||
|
if name == "progress_bar_cmd":
|
||||||
|
ui = value
|
||||||
|
else:
|
||||||
|
st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.")
|
||||||
|
ui = value
|
||||||
|
return ui
|
||||||
|
|
||||||
|
|
||||||
|
def launch_webui():
|
||||||
|
input_col, output_col = st.columns(2)
|
||||||
|
with input_col:
|
||||||
|
if "available_pipelines" not in st.session_state:
|
||||||
|
st.session_state["available_pipelines"] = parse_available_pipelines()
|
||||||
|
if "available_examples" not in st.session_state:
|
||||||
|
st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
|
||||||
|
|
||||||
|
with st.expander("Pipeline", expanded=True):
|
||||||
|
pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
|
||||||
|
example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
|
||||||
|
if example != "None":
|
||||||
|
st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
|
||||||
|
if st.button("Step 1: Parse Pipeline", type="primary"):
|
||||||
|
st.session_state["pipeline_class"] = pipeline_class
|
||||||
|
|
||||||
|
if "pipeline_class" not in st.session_state:
|
||||||
|
return
|
||||||
|
with st.expander("Model", expanded=True):
|
||||||
|
input_params = {}
|
||||||
|
params = parse_params(pipeline_class.from_pretrained)
|
||||||
|
for param in params:
|
||||||
|
input_params[param["name"]] = draw_ui_element(**param)
|
||||||
|
if st.button("Step 2: Load Models", type="primary"):
|
||||||
|
with st.spinner("Loading models", show_time=True):
|
||||||
|
if "pipe" in st.session_state:
|
||||||
|
del st.session_state["pipe"]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
|
||||||
|
|
||||||
|
if "pipe" not in st.session_state:
|
||||||
|
return
|
||||||
|
with st.expander("Input", expanded=True):
|
||||||
|
pipe = st.session_state["pipe"]
|
||||||
|
input_params = {}
|
||||||
|
params = parse_params(pipe.__call__)
|
||||||
|
for param in params:
|
||||||
|
if param["name"] in ["self"]:
|
||||||
|
continue
|
||||||
|
input_params[param["name"]] = draw_ui_element(**param)
|
||||||
|
|
||||||
|
with output_col:
|
||||||
|
if st.button("Step 3: Generate", type="primary"):
|
||||||
|
if "progress_bar_cmd" in input_params:
|
||||||
|
input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
|
||||||
|
result = pipe(**input_params)
|
||||||
|
st.session_state["result"] = result
|
||||||
|
|
||||||
|
if "result" in st.session_state:
|
||||||
|
result = st.session_state["result"]
|
||||||
|
if isinstance(result, Image.Image):
|
||||||
|
st.image(result)
|
||||||
|
else:
|
||||||
|
print(f"unsupported result format: {result}")
|
||||||
|
|
||||||
|
launch_webui()
|
||||||
|
# streamlit run examples/dev_tools/webui.py --server.fileWatcherType none
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": torch.bfloat16,
|
"offload_dtype": torch.bfloat16,
|
||||||
@@ -25,3 +25,8 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
|
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
|
||||||
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||||
image.save("image_FLUX.2-dev.jpg")
|
image.save("image_FLUX.2-dev.jpg")
|
||||||
|
|
||||||
|
prompt = "Transform the image into Japanese anime style"
|
||||||
|
edit_image = [Image.open("image_FLUX.2-dev.jpg")]
|
||||||
|
image = pipe(prompt, seed=42, rand_device="cuda", edit_image=edit_image, num_inference_steps=50, embedded_guidance=2.5)
|
||||||
|
image.save("image_FLUX.2-dev_edit.jpg")
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
vram_config = {
|
vram_config = {
|
||||||
"offload_dtype": "disk",
|
"offload_dtype": "disk",
|
||||||
@@ -24,4 +25,9 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
)
|
)
|
||||||
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
||||||
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||||
image.save("image.jpg")
|
image.save("image.jpg")
|
||||||
|
|
||||||
|
prompt = "Transform the image into Japanese anime style"
|
||||||
|
edit_image = [Image.open("image.jpg")]
|
||||||
|
image = pipe(prompt, seed=42, rand_device="cuda", edit_image=edit_image, num_inference_steps=50, embedded_guidance=2.5)
|
||||||
|
image.save("image_edit.jpg")
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "diffsynth"
|
name = "diffsynth"
|
||||||
version = "2.0.6"
|
version = "2.0.7"
|
||||||
description = "Enjoy the magic of Diffusion models!"
|
description = "Enjoy the magic of Diffusion models!"
|
||||||
authors = [{name = "ModelScope Team"}]
|
authors = [{name = "ModelScope Team"}]
|
||||||
license = {text = "Apache-2.0"}
|
license = {text = "Apache-2.0"}
|
||||||
|
|||||||
Reference in New Issue
Block a user