DiffSynth-Studio 2.0 major update

This commit is contained in:
root
2025-12-04 16:33:07 +08:00
parent afd101f345
commit 72af7122b3
758 changed files with 26462 additions and 2221398 deletions

View File

@@ -1,48 +1,18 @@
import torch
import torch, math
from PIL import Image
from typing import Union
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from ..models import ModelManager, load_state_dict
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
processed_conditionings = []
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
processed_conditionings.append(model_output)
return processed_conditionings
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
res = 0
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
res = res + model_output * controlnet_input.scale
return res
class QwenImagePipeline(BasePipeline):
@@ -54,14 +24,13 @@ class QwenImagePipeline(BasePipeline):
)
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
self.scheduler = FlowMatchScheduler("Qwen-Image")
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.blockwise_controlnet: QwenImageBlockwiseMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.processor: Qwen2VLProcessor = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit", "blockwise_controlnet")
self.units = [
QwenImageUnit_ShapeChecker(),
@@ -75,245 +44,6 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_BlockwiseControlNet(),
]
self.model_fn = model_fn_qwen_image
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
noise = torch.randn_like(inputs["input_latents"])
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
def direct_distill_loss(self, **inputs):
self.scheduler.set_timesteps(inputs["num_inference_steps"])
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(self.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
inputs["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
return loss
def _enable_fp8_lora_training(self, dtype):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
from ..models.qwen_image_dit import RMSNorm
from ..models.qwen_image_vae import QwenImageRMS_norm
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
QwenImageRMS_norm: AutoWrappedModule,
}
model_config = dict(
offload_dtype=dtype,
offload_device="cuda",
onload_dtype=dtype,
onload_device="cuda",
computation_dtype=self.torch_dtype,
computation_device="cuda",
)
if self.text_encoder is not None:
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
if self.dit is not None:
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
if self.vae is not None:
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, auto_offload=True, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if vram_limit is None and auto_offload:
vram_limit = self.get_vram()
if vram_limit is not None:
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding
dtype = next(iter(self.text_encoder.parameters())).dtype
enable_vram_management(
self.text_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
Qwen2_5_VLRotaryEmbedding: AutoWrappedModule,
Qwen2RMSNorm: AutoWrappedModule,
Qwen2_5_VisionPatchEmbed: AutoWrappedModule,
Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
if not enable_dit_fp8_computation:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
else:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
enable_vram_management(
self.vae,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv3d: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
QwenImageRMS_norm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.blockwise_controlnet is not None:
enable_vram_management(
self.blockwise_controlnet,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod
@@ -323,24 +53,18 @@ class QwenImagePipeline(BasePipeline):
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
vram_limit: float = None,
):
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary()
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = QwenImagePipeline(device=device, torch_dtype=torch_dtype)
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_manager.fetch_model("qwen_image_blockwise_controlnet", index="all"))
if tokenizer_config is not None and pipe.text_encoder is not None:
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("qwen_image_text_encoder")
pipe.dit = model_pool.fetch_model("qwen_image_dit")
pipe.vae = model_pool.fetch_model("qwen_image_vae")
pipe.blockwise_controlnet = QwenImageBlockwiseMultiControlNet(model_pool.fetch_model("qwen_image_blockwise_controlnet", index="all"))
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
pipe.tokenizer = Qwen2Tokenizer.from_pretrained(tokenizer_config.path)
@@ -348,6 +72,9 @@ class QwenImagePipeline(BasePipeline):
processor_config.download_if_necessary()
from transformers import Qwen2VLProcessor
pipe.processor = Qwen2VLProcessor.from_pretrained(processor_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@@ -386,8 +113,6 @@ class QwenImagePipeline(BasePipeline):
edit_rope_interpolation: bool = False,
# In-context control
context_image: Image.Image = None,
# FP8
enable_fp8_attention: bool = False,
# Tile
tiled: bool = False,
tile_size: int = 128,
@@ -411,7 +136,6 @@ class QwenImagePipeline(BasePipeline):
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
@@ -427,16 +151,11 @@ class QwenImagePipeline(BasePipeline):
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
if cfg_scale != 1.0:
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# Scheduler
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
@@ -448,10 +167,41 @@ class QwenImagePipeline(BasePipeline):
return image
class QwenImageBlockwiseMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageBlockWiseControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
for model in models:
if hasattr(model, "vram_management_enabled") and getattr(model, "vram_management_enabled"):
self.vram_management_enabled = True
def preprocess(self, controlnet_inputs: list[ControlNetInput], conditionings: list[torch.Tensor], **kwargs):
processed_conditionings = []
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
conditioning = rearrange(conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
model_output = self.models[controlnet_input.controlnet_id].process_controlnet_conditioning(conditioning)
processed_conditionings.append(model_output)
return processed_conditionings
def blockwise_forward(self, image, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, block_id, **kwargs):
res = 0
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
model_output = self.models[controlnet_input.controlnet_id].blockwise_forward(image, conditioning, block_id)
res = res + model_output * controlnet_input.scale
return res
class QwenImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width"))
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: QwenImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
@@ -461,7 +211,10 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
class QwenImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device"))
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
@@ -473,6 +226,7 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
@@ -494,6 +248,7 @@ class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
output_params=("inpaint_mask",),
)
def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
@@ -515,6 +270,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
input_params=("edit_image",),
output_params=("prompt_emb", "prompt_emb_mask"),
onload_model_names=("text_encoder",)
)
@@ -526,7 +282,6 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
return split_result
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
@@ -573,6 +328,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
return split_hidden_states
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
pipe.load_models_to_device(self.onload_model_names)
if pipe.text_encoder is not None:
prompt = [prompt]
if edit_image is None:
@@ -595,6 +351,8 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
onload_model_names=("text_encoder",)
)
@@ -675,6 +433,7 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("blockwise_controlnet_inputs", "tiled", "tile_size", "tile_stride"),
output_params=("blockwise_controlnet_conditioning",),
onload_model_names=("vae",)
)
@@ -717,6 +476,7 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image"),
onload_model_names=("vae",)
)
@@ -738,7 +498,7 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
if edit_image is None:
return {}
pipe.load_models_to_device(['vae'])
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
@@ -759,13 +519,14 @@ class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("context_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
if context_image is None:
return {}
pipe.load_models_to_device(['vae'])
pipe.load_models_to_device(self.onload_model_names)
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return {"context_latents": context_latents}