Compare commits

...

5 Commits

Author SHA1 Message Date
mi804
b2d4bc8dd8 block wise controlnet 2025-08-12 13:10:47 +08:00
Artiprocher
c8ea3caf39 bugfix 2025-08-08 12:49:59 +08:00
Artiprocher
0d519ee08a bugfix 2025-08-08 12:47:04 +08:00
Artiprocher
6e13deb6de qwen-image controlnet 2025-08-08 11:29:23 +08:00
Zhongjie Duan
32cf5d32ce Qwen-Image FP8 (#761)
* support qwen-image-fp8

* refine README

* bugfix

* bugfix
2025-08-07 16:56:02 +08:00
16 changed files with 701 additions and 46 deletions

View File

@@ -75,6 +75,8 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
model_loader_configs = [ model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -167,6 +169,8 @@ model_loader_configs = [
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"), (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"), (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.

View File

@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
from ..vram_management import gradient_checkpoint_forward
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
class QwenImageControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
num_controlnet_layers: int = 6,
):
super().__init__()
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64 * 2, 3072)
self.txt_in = nn.Linear(3584, 3072)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=3072,
num_attention_heads=24,
attention_head_dim=128,
)
for _ in range(num_controlnet_layers)
]
)
self.proj_out = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for i in range(num_layers)])
self.num_layers = num_layers
self.num_controlnet_layers = num_controlnet_layers
self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
def forward(
self,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_conditioning=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = torch.concat([image, controlnet_conditioning], dim=-1)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
conditioning = self.time_text_embed(timestep, image.dtype)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
outputs = []
for block in self.transformer_blocks:
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
outputs.append(image)
outputs_aligned = [self.proj_out[i](outputs[self.align_map[i]]) for i in range(self.num_layers)]
return outputs_aligned
@staticmethod
def state_dict_converter():
return QwenImageControlNetStateDictConverter()
class QwenImageControlNetStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
return state_dict
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072):
super().__init__()
self.x_rms = RMSNorm(dim, eps=1e-6)
self.y_rms = RMSNorm(dim, eps=1e-6)
self.input_proj = nn.Linear(dim, dim)
self.act = nn.GELU()
self.output_proj = nn.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
def init_weights(self):
# zero initialize output_proj
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
dim: int = 3072,
):
super().__init__()
self.img_in = nn.Linear(in_dim, dim)
self.controlnet_blocks = nn.ModuleList(
[
BlockWiseControlBlock(dim)
for _ in range(num_layers)
]
)
def init_weight(self):
nn.init.zeros_(self.img_in.weight)
nn.init.zeros_(self.img_in.bias)
for block in self.controlnet_blocks:
block.init_weights()
def process_controlnet_conditioning(self, controlnet_conditioning):
return self.img_in(controlnet_conditioning)
def blockwise_forward(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
@staticmethod
def state_dict_converter():
return QwenImageBlockWiseControlNetStateDictConverter()
class QwenImageBlockWiseControlNetStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
return state_dict

View File

@@ -1,10 +1,44 @@
import torch import torch, math
import torch.nn as nn import torch.nn as nn
from typing import Tuple, Optional, Union, List from typing import Tuple, Optional, Union, List
from einops import rearrange from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm from .flux_dit import AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module): class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -160,6 +194,7 @@ class QwenDoubleStreamAttention(nn.Module):
text: torch.FloatTensor, text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -187,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2) joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2) joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask) joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :] txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :] img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -247,6 +280,7 @@ class QwenImageTransformerBlock(nn.Module):
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -263,6 +297,7 @@ class QwenImageTransformerBlock(nn.Module):
text=txt_modulated, text=txt_modulated,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask, attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
) )
image = image + img_gate * img_attn_out image = image + img_gate * img_attn_out
@@ -387,7 +422,7 @@ class QwenImageDiT(torch.nn.Module):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2) image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = self.img_in(image) image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb)) text = self.txt_in(self.txt_norm(prompt_emb))
@@ -406,7 +441,7 @@ class QwenImageDiT(torch.nn.Module):
image = self.norm_out(image, conditioning) image = self.norm_out(image, conditioning)
image = self.proj_out(image) image = self.proj_out(image)
latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return image return image
@staticmethod @staticmethod

View File

@@ -4,19 +4,55 @@ from typing import Union
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from einops import rearrange from einops import rearrange
import numpy as np
from ..models import ModelManager, load_state_dict from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
return res_stack
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack = None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class QwenImagePipeline(BasePipeline): class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16): def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -30,15 +66,17 @@ class QwenImagePipeline(BasePipeline):
self.text_encoder: QwenImageTextEncoder = None self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None self.vae: QwenImageVAE = None
self.controlnet: QwenImageMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner() self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit",) self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
self.units = [ self.units = [
QwenImageUnit_ShapeChecker(), QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(), QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(), QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_PromptEmbedder(), QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(), QwenImageUnit_EntityControl(),
QwenImageUnit_ControlNet(),
] ]
self.model_fn = model_fn_qwen_image self.model_fn = model_fn_qwen_image
@@ -63,14 +101,12 @@ class QwenImagePipeline(BasePipeline):
return loss return loss
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True self.vram_management_enabled = True
if num_persistent_param_in_dit is not None: if vram_limit is None:
vram_limit = None vram_limit = self.get_vram()
else: vram_limit = vram_limit - vram_buffer
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None: if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -96,31 +132,54 @@ class QwenImagePipeline(BasePipeline):
from ..models.qwen_image_dit import RMSNorm from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device device = "cpu" if vram_limit is not None else self.device
enable_vram_management( if not enable_dit_fp8_computation:
self.dit, enable_vram_management(
module_map = { self.dit,
RMSNorm: AutoWrappedModule, module_map = {
torch.nn.Linear: AutoWrappedLinear, RMSNorm: AutoWrappedModule,
}, torch.nn.Linear: AutoWrappedLinear,
module_config = dict( },
offload_dtype=dtype, module_config = dict(
offload_device="cpu", offload_dtype=dtype,
onload_dtype=dtype, offload_device="cpu",
onload_device=device, onload_dtype=dtype,
computation_dtype=self.torch_dtype, onload_device=device,
computation_device=self.device, computation_dtype=self.torch_dtype,
), computation_device=self.device,
max_num_param=num_persistent_param_in_dit, ),
overflow_module_config = dict( vram_limit=vram_limit,
offload_dtype=dtype, )
offload_device="cpu", else:
onload_dtype=dtype, enable_vram_management(
onload_device="cpu", self.dit,
computation_dtype=self.torch_dtype, module_map = {
computation_device=self.device, RMSNorm: AutoWrappedModule,
), },
vram_limit=vram_limit, module_config = dict(
) offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None: if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype dtype = next(iter(self.vae.parameters())).dtype
@@ -166,6 +225,8 @@ class QwenImagePipeline(BasePipeline):
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder") pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit") pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae") pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
if tokenizer_config is not None and pipe.text_encoder is not None: if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary() tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer from transformers import Qwen2Tokenizer
@@ -191,10 +252,14 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu", rand_device: str = "cpu",
# Steps # Steps
num_inference_steps: int = 30, num_inference_steps: int = 30,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# EliGen # EliGen
eligen_entity_prompts: list[str] = None, eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None, eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False, eligen_enable_on_negative: bool = False,
# FP8
enable_fp8_attention: bool = False,
# Tile # Tile
tiled: bool = False, tiled: bool = False,
tile_size: int = 128, tile_size: int = 128,
@@ -217,6 +282,9 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength, "input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
} }
@@ -407,21 +475,88 @@ class QwenImageUnit_EntityControl(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega return inputs_shared, inputs_posi, inputs_nega
class QwenImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
mask = (pipe.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, pipe, image, mask):
mask = mask.resize(image.size)
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
pipe.load_models_to_device(self.onload_model_names)
conditionings = []
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {return_key: conditionings}
def model_fn_qwen_image( def model_fn_qwen_image(
dit: QwenImageDiT = None, dit: QwenImageDiT = None,
controlnet: QwenImageMultiControlNet = None,
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
latents=None, latents=None,
timestep=None, timestep=None,
prompt_emb=None, prompt_emb=None,
prompt_emb_mask=None, prompt_emb_mask=None,
height=None, height=None,
width=None, width=None,
controlnet_inputs=None,
controlnet_conditionings=None,
blockwise_controlnet_conditioning=None,
progress_id=0,
num_inference_steps=1,
entity_prompt_emb=None, entity_prompt_emb=None,
entity_prompt_emb_mask=None, entity_prompt_emb_mask=None,
entity_masks=None, entity_masks=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
**kwargs **kwargs
): ):
# ControlNet
if controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"latents": latents,
"timestep": timestep,
"prompt_emb": prompt_emb,
"prompt_emb_mask": prompt_emb_mask,
"height": height,
"width": width,
"use_gradient_checkpointing": use_gradient_checkpointing,
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
}
res_stack = controlnet(
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
**controlnet_extra_kwargs
)
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)] img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist() txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000 timestep = timestep / 1000
@@ -441,7 +576,14 @@ def model_fn_qwen_image(
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device) image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None attention_mask = None
for block in dit.transformer_blocks: if blockwise_controlnet_conditioning is not None:
blockwise_controlnet_conditioning = rearrange(
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
)
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
# blockwise_controlnet_conditioning =
for block_id, block in enumerate(dit.transformer_blocks):
text, image = gradient_checkpoint_forward( text, image = gradient_checkpoint_forward(
block, block,
use_gradient_checkpointing, use_gradient_checkpointing,
@@ -451,8 +593,13 @@ def model_fn_qwen_image(
temb=conditioning, temb=conditioning,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask, attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
) )
if blockwise_controlnet is not None:
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
if controlnet_conditionings is not None:
image = image + res_stack[block_id]
image = dit.norm_out(image, conditioning) image = dit.norm_out(image, conditioning)
image = dit.proj_out(image) image = dit.proj_out(image)

View File

@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = [] self.lora_A_weights = []
self.lora_B_weights = [] self.lora_B_weights = []
self.lora_merger = None self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def forward(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs):
# VRAM management
if self.state == 2: if self.state == 2:
weight, bias = self.weight, self.bias weight, bias = self.weight, self.bias
else: else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else: else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device) weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
# Linear forward
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
# LoRA
if len(self.lora_A_weights) == 0: if len(self.lora_A_weights) == 0:
# No LoRA # No LoRA
return out return out

View File

@@ -164,6 +164,7 @@ After enabling VRAM management, the framework will automatically choose a memory
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use. * `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer. * `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
* `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it. * `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. This is only applicable to GPUs that support FP8 operations (e.g., H200, etc.). Disabled by default.
</details> </details>
@@ -172,7 +173,11 @@ After enabling VRAM management, the framework will automatically choose a memory
<summary>Inference Acceleration</summary> <summary>Inference Acceleration</summary>
Inference acceleration for Qwen-Image is under development. Please stay tuned! * FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details> </details>

View File

@@ -164,6 +164,7 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
* `vram_limit`: 显存占用量限制GB默认占用设备上的剩余显存。注意这不是一个绝对限制当设置的显存不足以支持模型进行推理但实际可用显存足够时将会以最小化显存占用的形式进行推理。将其设置为0时将会实现理论最小显存占用。 * `vram_limit`: 显存占用量限制GB默认占用设备上的剩余显存。注意这不是一个绝对限制当设置的显存不足以支持模型进行推理但实际可用显存足够时将会以最小化显存占用的形式进行推理。将其设置为0时将会实现理论最小显存占用。
* `vram_buffer`: 显存缓冲区大小GB默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。 * `vram_buffer`: 显存缓冲区大小GB默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。 * `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
* `enable_dit_fp8_computation`: 是否启用 DiT 模型中的 FP8 计算,仅适用于支持 FP8 运算的 GPU例如 H200 等),默认不启用。
</details> </details>
@@ -172,7 +173,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
<summary>推理加速</summary> <summary>推理加速</summary>
Qwen-Image 的推理加速技术正在开发中,敬请期待! * FP8 量化:根据您的硬件与需求,请选择合适的量化方式
* GPU 不支持 FP8 计算(例如 A100、4090 等FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details> </details>

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management(enable_dit_fp8_computation=True)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.models.qwen_image_dit import RMSNorm
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
enable_vram_management(
pipe.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=torch.bfloat16,
offload_device="cuda",
onload_dtype=torch.bfloat16,
onload_device="cuda",
computation_dtype=torch.bfloat16,
computation_device="cuda",
),
vram_limit=None,
)
enable_vram_management(
pipe.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=torch.float8_e4m3fn,
offload_device="cuda",
onload_dtype=torch.float8_e4m3fn,
onload_device="cuda",
computation_dtype=torch.float8_e4m3fn,
computation_device="cuda",
),
vram_limit=None,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,36 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path "" \
--dataset_metadata_path data/t2i_dataset_annotations/blip3o/blip3o_control_images_train_for_diffsynth.jsonl \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"
]' \
--learning_rate 1e-3 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet." \
--output_path "./models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--save_steps 2000

View File

@@ -0,0 +1,35 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 80000 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/controlnet.safetensors"
]' \
--learning_rate 1e-5 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/Qwen-Image-ControlNet_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--save_steps 100

View File

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,13 @@
# This script is for initializing a Qwen-Image-ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
import torch
from safetensors.torch import save_file
controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda")
controlnet.init_weight()
state_dict_controlnet = controlnet.state_dict()
print(hash_state_dict_keys(state_dict_controlnet))
save_file(state_dict_controlnet, "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors")

View File

@@ -0,0 +1,34 @@
# This script is for initializing a Qwen-Image-ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageControlNet
import torch
from safetensors.torch import save_file
state_dict_dit = {}
for i in range(1, 10):
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
state_dict_controlnet = controlnet.state_dict()
state_dict_init = {}
for k in state_dict_controlnet:
if k in state_dict_dit:
if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
state_dict_init[k] = state_dict_dit[k]
elif k == "img_in.weight":
state_dict_init[k] = torch.concat(
[
state_dict_dit[k],
state_dict_dit[k],
],
dim=-1
)
else:
print("Zero Initialized:", k)
state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
controlnet.load_state_dict(state_dict_init)
print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/controlnet.safetensors")

View File

@@ -1,5 +1,5 @@
import torch, os, json import torch, os, json
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -73,8 +73,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
} }
# Extra inputs # Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs: for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input] inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
# Pipeline units will automatically process the input parameters. # Pipeline units will automatically process the input parameters.
for unit in self.pipe.units: for unit in self.pipe.units:
@@ -111,7 +118,7 @@ if __name__ == "__main__":
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
) )
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=0.000001)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task( launch_training_task(
dataset, model, model_logger, optimizer, scheduler, dataset, model, model_logger, optimizer, scheduler,

View File

@@ -0,0 +1,38 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
import torch
from PIL import Image
from diffsynth.controlnets.processors import Annotator
import os
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6/step-26000.safetensors")
pipe.blockwise_controlnet.load_state_dict(state_dict)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = Image.open("test_image.jpg").convert("RGB").resize((1024, 1024))
canny_image = Annotator("canny")(image)
canny_image.save("canny_image_test.jpg")
controlnet_input = ControlNetInput(
image=canny_image,
scale=1.0,
processor_id="canny",
)
for seed in range(100, 200):
image = pipe(prompt, seed=seed, height=1024, width=1024, controlnet_inputs=[controlnet_input], num_inference_steps=30, cfg_scale=4.0)
image.save(f"test_image_controlnet_step2k_1_{seed}.jpg")