Compare commits

...

4 Commits

Author        SHA1        Message                 Date
mi804         b2d4bc8dd8  block wise controlnet   2025-08-12 13:10:47 +08:00
Artiprocher   c8ea3caf39  bugfix                  2025-08-08 12:49:59 +08:00
Artiprocher   0d519ee08a  bugfix                  2025-08-08 12:47:04 +08:00
Artiprocher   6e13deb6de  qwen-image controlnet   2025-08-08 11:29:23 +08:00
11 changed files with 476 additions and 7 deletions

View File

@@ -75,6 +75,8 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet
from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -167,6 +169,8 @@ model_loader_configs = [
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"), (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"), (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
(None, "be2500a62936a43d5367a70ea001e25d", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
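Note: the two new entries register the ControlNet weights for automatic detection. Each tuple maps a hash of the checkpoint's state-dict keys to a model name, a model class, and the expected key format. A minimal sketch of how such a hash-based lookup could work, reusing load_state_dict and hash_state_dict_keys (used in the init scripts below); the detect_model_class helper here is hypothetical and only for illustration:

import torch
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet

# Hypothetical registry: state-dict-key hash -> (model name, model class, key format),
# mirroring the entries added to model_loader_configs above.
REGISTERED = {
    "073bce9cf969e317e5662cd570c3e79c": ("qwen_image_blockwise_controlnet", QwenImageBlockWiseControlNet, "civitai"),
}

def detect_model_class(path):  # hypothetical helper, not part of the library
    state_dict = load_state_dict(path)
    return REGISTERED.get(hash_state_dict_keys(state_dict))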

View File

@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
from ..vram_management import gradient_checkpoint_forward
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
class QwenImageControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
num_controlnet_layers: int = 6,
):
super().__init__()
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
self.txt_norm = RMSNorm(3584, eps=1e-6)
self.img_in = nn.Linear(64 * 2, 3072)
self.txt_in = nn.Linear(3584, 3072)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=3072,
num_attention_heads=24,
attention_head_dim=128,
)
for _ in range(num_controlnet_layers)
]
)
self.proj_out = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for i in range(num_layers)])
self.num_layers = num_layers
self.num_controlnet_layers = num_controlnet_layers
self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
def forward(
self,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_conditioning=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = torch.concat([image, controlnet_conditioning], dim=-1)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
conditioning = self.time_text_embed(timestep, image.dtype)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
outputs = []
for block in self.transformer_blocks:
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
image=image,
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
)
outputs.append(image)
outputs_aligned = [self.proj_out[i](outputs[self.align_map[i]]) for i in range(self.num_layers)]
return outputs_aligned
@staticmethod
def state_dict_converter():
return QwenImageControlNetStateDictConverter()
class QwenImageControlNetStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
return state_dict
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072):
super().__init__()
self.x_rms = RMSNorm(dim, eps=1e-6)
self.y_rms = RMSNorm(dim, eps=1e-6)
self.input_proj = nn.Linear(dim, dim)
self.act = nn.GELU()
self.output_proj = nn.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
def init_weights(self):
# zero initialize output_proj
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
dim: int = 3072,
):
super().__init__()
self.img_in = nn.Linear(in_dim, dim)
self.controlnet_blocks = nn.ModuleList(
[
BlockWiseControlBlock(dim)
for _ in range(num_layers)
]
)
def init_weight(self):
nn.init.zeros_(self.img_in.weight)
nn.init.zeros_(self.img_in.bias)
for block in self.controlnet_blocks:
block.init_weights()
def process_controlnet_conditioning(self, controlnet_conditioning):
return self.img_in(controlnet_conditioning)
def blockwise_forward(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
@staticmethod
def state_dict_converter():
return QwenImageBlockWiseControlNetStateDictConverter()
class QwenImageBlockWiseControlNetStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
return state_dict
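Note: neither new module runs the DiT on its own; the pipeline's model_fn adds their outputs to the hidden states of the 60 transformer blocks. QwenImageControlNet runs only num_controlnet_layers blocks and fans each result out to several DiT layers through align_map plus a per-layer proj_out, while QwenImageBlockWiseControlNet applies one zero-initialized BlockWiseControlBlock per DiT layer. A minimal sketch of the fan-out, with toy tensors standing in for hidden states and a shrunken feature dimension:

import torch

num_layers, num_controlnet_layers, dim = 60, 6, 8  # dim shrunk for illustration; the real model uses 3072
# Each of the 60 DiT layers reads from one of the 6 ControlNet block outputs.
align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
assert align_map[0] == 0 and align_map[9] == 0 and align_map[10] == 1 and align_map[59] == 5

# Toy stand-ins for the ControlNet block outputs, shape (batch, seq_len, dim).
outputs = [torch.randn(1, 16, dim) for _ in range(num_controlnet_layers)]
proj_out = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(num_layers)])

# One projected residual per DiT layer; model_fn adds residuals[block_id] to the
# hidden states after the corresponding DiT block runs.
residuals = [proj_out[i](outputs[align_map[i]]) for i in range(num_layers)]
assert len(residuals) == num_layers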

View File

@@ -422,7 +422,7 @@ class QwenImageDiT(torch.nn.Module):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
- image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
+ image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
image = self.img_in(image)
text = self.txt_in(self.txt_norm(prompt_emb))
@@ -441,7 +441,7 @@ class QwenImageDiT(torch.nn.Module):
image = self.norm_out(image, conditioning)
image = self.proj_out(image)
- latents = rearrange(image, "B (H W) (P Q C) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
+ latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return image
@staticmethod
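Note: the two "bugfix" commits only swap the patch-packing order from (P Q C) to (C P Q), so each 2x2 latent patch is flattened channel-first, matching the packing used by the ControlNet modules above. A small sketch with toy shapes of what the corrected pattern does:

import torch
from einops import rearrange

# Toy latent: batch 1, 16 channels, 8x8 latent pixels (a 4x4 grid of 2x2 patches).
latents = torch.randn(1, 16, 8, 8)

# Corrected packing: each 2x2 patch becomes one token of length C*2*2, channel-first.
packed = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
assert packed.shape == (1, 16, 64)  # 4*4 tokens, 16*2*2 features per token

# Unpacking with the matching pattern recovers the original tensor exactly.
unpacked = rearrange(packed, "B (H W) (C P Q) -> B C (H P) (W Q)", H=4, W=4, P=2, Q=2)
assert torch.equal(latents, unpacked)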

View File

@@ -4,19 +4,55 @@ from typing import Union
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import numpy as np
from ..models import ModelManager, load_state_dict
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
from ..models.qwen_image_controlnet import QwenImageControlNet, QwenImageBlockWiseControlNet
from ..schedulers import FlowMatchScheduler
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora import GeneralLoRALoader
from .flux_image_new import ControlNetInput
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
class QwenImageMultiControlNet(torch.nn.Module):
def __init__(self, models: list[QwenImageControlNet]):
super().__init__()
if not isinstance(models, list):
models = [models]
self.models = torch.nn.ModuleList(models)
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
model = self.models[controlnet_input.controlnet_id]
res_stack = model(
controlnet_conditioning=conditioning,
processor_id=controlnet_input.processor_id,
**kwargs
)
res_stack = [res * controlnet_input.scale for res in res_stack]
return res_stack
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
res_stack = None
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
if progress > controlnet_input.start + (1e-4) or progress < controlnet_input.end - (1e-4):
continue
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack
class QwenImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@@ -30,15 +66,17 @@ class QwenImagePipeline(BasePipeline):
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.controlnet: QwenImageMultiControlNet = None
self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner()
- self.in_iteration_models = ("dit",)
+ self.in_iteration_models = ("dit", "controlnet", "blockwise_controlnet")
self.units = [
QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_ControlNet(),
]
self.model_fn = model_fn_qwen_image
@@ -187,6 +225,8 @@ class QwenImagePipeline(BasePipeline):
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
pipe.blockwise_controlnet = model_manager.fetch_model("qwen_image_blockwise_controlnet")
if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
@@ -212,6 +252,8 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# EliGen
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
@@ -241,6 +283,8 @@ class QwenImagePipeline(BasePipeline):
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention, "enable_fp8_attention": enable_fp8_attention,
"num_inference_steps": num_inference_steps,
"controlnet_inputs": controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
} }
@@ -431,14 +475,63 @@ class QwenImageUnit_EntityControl(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega
class QwenImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
mask = (pipe.preprocess_image(mask) + 1) / 2
mask = mask.mean(dim=1, keepdim=True)
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
latents = torch.concat([latents, mask], dim=1)
return latents
def apply_controlnet_mask_on_image(self, pipe, image, mask):
mask = mask.resize(image.size)
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
image = np.array(image)
image[mask > 0] = 0
image = Image.fromarray(image)
return image
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
if controlnet_inputs is None:
return {}
return_key = "blockwise_controlnet_conditioning" if pipe.blockwise_controlnet is not None else "controlnet_conditionings"
pipe.load_models_to_device(self.onload_model_names)
conditionings = []
for controlnet_input in controlnet_inputs:
image = controlnet_input.image
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if controlnet_input.inpaint_mask is not None:
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
conditionings.append(image)
return {return_key: conditionings}
def model_fn_qwen_image(
dit: QwenImageDiT = None,
controlnet: QwenImageMultiControlNet = None,
blockwise_controlnet: QwenImageBlockWiseControlNet = None,
latents=None,
timestep=None,
prompt_emb=None,
prompt_emb_mask=None,
height=None,
width=None,
controlnet_inputs=None,
controlnet_conditionings=None,
blockwise_controlnet_conditioning=None,
progress_id=0,
num_inference_steps=1,
entity_prompt_emb=None,
entity_prompt_emb_mask=None,
entity_masks=None,
@@ -447,6 +540,23 @@ def model_fn_qwen_image(
use_gradient_checkpointing_offload=False,
**kwargs
):
# ControlNet
if controlnet_conditionings is not None:
controlnet_extra_kwargs = {
"latents": latents,
"timestep": timestep,
"prompt_emb": prompt_emb,
"prompt_emb_mask": prompt_emb_mask,
"height": height,
"width": width,
"use_gradient_checkpointing": use_gradient_checkpointing,
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
}
res_stack = controlnet(
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
**controlnet_extra_kwargs
)
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
timestep = timestep / 1000
@@ -466,7 +576,14 @@ def model_fn_qwen_image(
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None
- for block in dit.transformer_blocks:
+ if blockwise_controlnet_conditioning is not None:
blockwise_controlnet_conditioning = rearrange(
blockwise_controlnet_conditioning[0], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2
)
blockwise_controlnet_conditioning = blockwise_controlnet.process_controlnet_conditioning(blockwise_controlnet_conditioning)
for block_id, block in enumerate(dit.transformer_blocks):
text, image = gradient_checkpoint_forward(
block,
use_gradient_checkpointing,
@@ -478,6 +595,10 @@ def model_fn_qwen_image(
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
if blockwise_controlnet is not None:
image = image + blockwise_controlnet.blockwise_forward(image, blockwise_controlnet_conditioning, block_id)
if controlnet_conditionings is not None:
image = image + res_stack[block_id]
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
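Note: QwenImageMultiControlNet sums residuals over all active ControlNets, and each ControlNetInput is gated by its start/end range. The progress value counts down from 1.0 at the first step to 0.0 at the last, and a ControlNet whose window does not cover the current progress is skipped. A small worked sketch of that gating, using the same arithmetic as the forward above; the start/end values here are illustrative, not defaults:

# Minimal sketch of the start/end gating in QwenImageMultiControlNet.forward,
# assuming a hypothetical window start=1.0, end=0.5 (apply only in the noisier half).
num_inference_steps = 30
start, end = 1.0, 0.5

active_steps = []
for progress_id in range(num_inference_steps):
    # progress counts down from 1.0 (first step) to 0.0 (last step)
    progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
    if progress > start + 1e-4 or progress < end - 1e-4:
        continue  # ControlNet skipped for this step
    active_steps.append(progress_id)

print(active_steps)  # steps 0 through 14, where progress >= 0.5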

View File

@@ -0,0 +1,36 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path "" \
--dataset_metadata_path data/t2i_dataset_annotations/blip3o/blip3o_control_images_train_for_diffsynth.jsonl \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"
]' \
--learning_rate 1e-3 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.blockwise_controlnet." \
--output_path "./models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6" \
--trainable_models "blockwise_controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8 \
--save_steps 2000

View File

@@ -0,0 +1,35 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 80000 \
--model_paths '[
[
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
],
[
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
],
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
"models/controlnet.safetensors"
]' \
--learning_rate 1e-5 \
--num_epochs 1000000 \
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
--output_path "./models/train/Qwen-Image-ControlNet_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--save_steps 100
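Note: the metadata file referenced above is not included in this diff. Based on --data_file_keys "image,controlnet_image", each row presumably needs at least those two file columns plus a caption; the layout below is purely hypothetical (the "prompt" column name in particular is an assumption), shown as a small Python sketch for generating such a file:

import csv

# Purely hypothetical metadata layout; only "image" and "controlnet_image" are implied
# by --data_file_keys above, and the "prompt" column name is an assumption.
rows = [
    {"image": "images/0001.jpg", "controlnet_image": "control/0001_lowres.jpg", "prompt": "a cat on a windowsill"},
    {"image": "images/0002.jpg", "controlnet_image": "control/0002_lowres.jpg", "prompt": "a mountain landscape at sunset"},
]
with open("metadata_controlnet_upscale.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "controlnet_image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)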

View File

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,13 @@
# This script is for initializing a Qwen-Image Block-Wise ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
import torch
from safetensors.torch import save_file
controlnet = QwenImageBlockWiseControlNet().to(dtype=torch.bfloat16, device="cuda")
controlnet.init_weight()
state_dict_controlnet = controlnet.state_dict()
print(hash_state_dict_keys(state_dict_controlnet))
save_file(state_dict_controlnet, "models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors")
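Note: because init_weight zeroes both img_in and every block's output_proj, the freshly initialized block-wise ControlNet contributes nothing to the DiT at the start of training. A small sketch verifying that no-op property, assuming the module is importable as in the script above:

import torch
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet

# A single block is enough to demonstrate the property (the real model uses 60).
controlnet = QwenImageBlockWiseControlNet(num_layers=1)
controlnet.init_weight()

img = torch.randn(1, 16, 3072)                                # toy DiT hidden states
raw_cond = torch.randn(1, 16, 64)                             # toy packed conditioning (in_dim=64)
cond = controlnet.process_controlnet_conditioning(raw_cond)   # img_in projection (zero-initialized)

# output_proj of every block is zero-initialized, so the injected residual starts at exactly zero.
residual = controlnet.blockwise_forward(img, cond, block_id=0)
assert torch.all(residual == 0)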

View File

@@ -0,0 +1,34 @@
# This script is for initializing a Qwen-Image-ControlNet
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageControlNet
import torch
from safetensors.torch import save_file
state_dict_dit = {}
for i in range(1, 10):
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
state_dict_controlnet = controlnet.state_dict()
state_dict_init = {}
for k in state_dict_controlnet:
if k in state_dict_dit:
if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
state_dict_init[k] = state_dict_dit[k]
elif k == "img_in.weight":
state_dict_init[k] = torch.concat(
[
state_dict_dit[k],
state_dict_dit[k],
],
dim=-1
)
else:
print("Zero Initialized:", k)
state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
controlnet.load_state_dict(state_dict_init)
print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/controlnet.safetensors")
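Note: the only shape mismatch when copying DiT weights is img_in. The ControlNet's img_in is nn.Linear(64 * 2, 3072) because its forward concatenates the packed latents with the packed conditioning along the feature dimension, so the DiT's (3072, 64) weight is duplicated along dim=-1 to give a (3072, 128) matrix. A quick shape check of that construction:

import torch

# The DiT's img_in maps 64 packed-latent features to 3072. The ControlNet concatenates
# packed latents with packed conditioning (64 + 64 features), so its img_in weight is
# (3072, 128); duplicating the DiT weight along dim=-1 produces exactly that shape.
dit_img_in_weight = torch.randn(3072, 64)
controlnet_img_in_weight = torch.concat([dit_img_in_weight, dit_img_in_weight], dim=-1)
assert controlnet_img_in_weight.shape == (3072, 128)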

View File

@@ -1,5 +1,5 @@
import torch, os, json
- from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -73,8 +73,15 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
}
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
@@ -111,7 +118,7 @@ if __name__ == "__main__":
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
)
- optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
+ optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=0.000001)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
launch_training_task(
dataset, model, model_logger, optimizer, scheduler,

View File

@@ -0,0 +1,38 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
import torch
from PIL import Image
from diffsynth.controlnets.processors import Annotator
import os
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
ModelConfig(path="models/DiffSynth-Studio/BlockWiseControlnet/model_init.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-BlockWiseControlNet_full_lr1e-3_wd1e-6/step-26000.safetensors")
pipe.blockwise_controlnet.load_state_dict(state_dict)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"  # English: "Exquisite portrait: an underwater girl in a flowing blue dress, hair drifting, translucent light, surrounded by bubbles, serene face, fine detail, dreamlike and beautiful."
image = Image.open("test_image.jpg").convert("RGB").resize((1024, 1024))
canny_image = Annotator("canny")(image)
canny_image.save("canny_image_test.jpg")
controlnet_input = ControlNetInput(
image=canny_image,
scale=1.0,
processor_id="canny",
)
for seed in range(100, 200):
image = pipe(prompt, seed=seed, height=1024, width=1024, controlnet_inputs=[controlnet_input], num_inference_steps=30, cfg_scale=4.0)
image.save(f"test_image_controlnet_step2k_1_{seed}.jpg")