mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Compare commits
6 Commits
qwen-image
...
qwen-image
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2abc97fc0f | ||
|
|
84ede171fd | ||
|
|
6f4e38276e | ||
|
|
829ca3414b | ||
|
|
26461c1963 | ||
|
|
0412fc7232 |
@@ -75,7 +75,6 @@ from ..models.nexus_gen import NexusGenAutoregressiveModel
|
|||||||
from ..models.qwen_image_dit import QwenImageDiT
|
from ..models.qwen_image_dit import QwenImageDiT
|
||||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
from ..models.qwen_image_vae import QwenImageVAE
|
from ..models.qwen_image_vae import QwenImageVAE
|
||||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
|
||||||
|
|
||||||
model_loader_configs = [
|
model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
@@ -168,7 +167,6 @@ model_loader_configs = [
|
|||||||
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
(None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
|
||||||
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
(None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
|
||||||
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
|
||||||
(None, "6834bf9ef1a6723291d82719fb02e953", ["qwen_image_controlnet"], [QwenImageControlNet], "civitai"),
|
|
||||||
]
|
]
|
||||||
huggingface_model_loader_configs = [
|
huggingface_model_loader_configs = [
|
||||||
# These configs are provided for detecting model type automatically.
|
# These configs are provided for detecting model type automatically.
|
||||||
|
|||||||
@@ -383,5 +383,20 @@ class WanLoRAConverter:
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageLoRAConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def align_to_opensource_format(state_dict, **kwargs):
|
||||||
|
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||||
|
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def get_lora_loaders():
|
def get_lora_loaders():
|
||||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
|
|
||||||
from ..vram_management import gradient_checkpoint_forward
|
|
||||||
from einops import rearrange
|
|
||||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNet(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_layers: int = 60,
|
|
||||||
num_controlnet_layers: int = 6,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)
|
|
||||||
|
|
||||||
self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
|
|
||||||
self.txt_norm = RMSNorm(3584, eps=1e-6)
|
|
||||||
|
|
||||||
self.img_in = nn.Linear(64 * 2, 3072)
|
|
||||||
self.txt_in = nn.Linear(3584, 3072)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
QwenImageTransformerBlock(
|
|
||||||
dim=3072,
|
|
||||||
num_attention_heads=24,
|
|
||||||
attention_head_dim=128,
|
|
||||||
)
|
|
||||||
for _ in range(num_controlnet_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.alpha = torch.nn.Parameter(torch.zeros((num_layers,)))
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.num_controlnet_layers = num_controlnet_layers
|
|
||||||
self.align_map = {i: i // (num_layers // num_controlnet_layers) for i in range(num_layers)}
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
latents=None,
|
|
||||||
timestep=None,
|
|
||||||
prompt_emb=None,
|
|
||||||
prompt_emb_mask=None,
|
|
||||||
height=None,
|
|
||||||
width=None,
|
|
||||||
controlnet_conditioning=None,
|
|
||||||
use_gradient_checkpointing=False,
|
|
||||||
use_gradient_checkpointing_offload=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
|
||||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
|
||||||
|
|
||||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
|
||||||
controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (P Q C)", H=height//16, W=width//16, P=2, Q=2)
|
|
||||||
image = torch.concat([image, controlnet_conditioning], dim=-1)
|
|
||||||
|
|
||||||
image = self.img_in(image)
|
|
||||||
text = self.txt_in(self.txt_norm(prompt_emb))
|
|
||||||
|
|
||||||
conditioning = self.time_text_embed(timestep, image.dtype)
|
|
||||||
|
|
||||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
text, image = gradient_checkpoint_forward(
|
|
||||||
block,
|
|
||||||
use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload,
|
|
||||||
image=image,
|
|
||||||
text=text,
|
|
||||||
temb=conditioning,
|
|
||||||
image_rotary_emb=image_rotary_emb,
|
|
||||||
)
|
|
||||||
outputs.append(image)
|
|
||||||
|
|
||||||
alpha = self.alpha.to(dtype=image.dtype, device=image.device)
|
|
||||||
outputs_aligned = [outputs[self.align_map[i]] * alpha[i] for i in range(self.num_layers)]
|
|
||||||
return outputs_aligned
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_dict_converter():
|
|
||||||
return QwenImageControlNetStateDictConverter()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNetStateDictConverter():
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
|
||||||
return state_dict
|
|
||||||
@@ -335,7 +335,7 @@ class WanModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.control_adapter = None
|
self.control_adapter = None
|
||||||
|
|
||||||
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
|
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
|
||||||
x = self.patch_embedding(x)
|
x = self.patch_embedding(x)
|
||||||
if self.control_adapter is not None and control_camera_latents_input is not None:
|
if self.control_adapter is not None and control_camera_latents_input is not None:
|
||||||
y_camera = self.control_adapter(control_camera_latents_input)
|
y_camera = self.control_adapter(control_camera_latents_input)
|
||||||
|
|||||||
@@ -4,54 +4,19 @@ from typing import Union
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from ..models import ModelManager, load_state_dict
|
from ..models import ModelManager, load_state_dict
|
||||||
from ..models.qwen_image_dit import QwenImageDiT
|
from ..models.qwen_image_dit import QwenImageDiT
|
||||||
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||||
from ..models.qwen_image_vae import QwenImageVAE
|
from ..models.qwen_image_vae import QwenImageVAE
|
||||||
from ..models.qwen_image_controlnet import QwenImageControlNet
|
|
||||||
from ..schedulers import FlowMatchScheduler
|
from ..schedulers import FlowMatchScheduler
|
||||||
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
from ..utils import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||||
from ..lora import GeneralLoRALoader
|
from ..lora import GeneralLoRALoader
|
||||||
from .flux_image_new import ControlNetInput
|
|
||||||
|
|
||||||
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageMultiControlNet(torch.nn.Module):
|
|
||||||
def __init__(self, models: list[QwenImageControlNet]):
|
|
||||||
super().__init__()
|
|
||||||
if not isinstance(models, list):
|
|
||||||
models = [models]
|
|
||||||
self.models = torch.nn.ModuleList(models)
|
|
||||||
|
|
||||||
def process_single_controlnet(self, controlnet_input: ControlNetInput, conditioning: torch.Tensor, **kwargs):
|
|
||||||
model = self.models[controlnet_input.controlnet_id]
|
|
||||||
res_stack = model(
|
|
||||||
controlnet_conditioning=conditioning,
|
|
||||||
processor_id=controlnet_input.processor_id,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
res_stack = [res * controlnet_input.scale for res in res_stack]
|
|
||||||
return res_stack
|
|
||||||
|
|
||||||
def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[ControlNetInput], progress_id, num_inference_steps, **kwargs):
|
|
||||||
res_stack = None
|
|
||||||
for controlnet_input, conditioning in zip(controlnet_inputs, conditionings):
|
|
||||||
progress = (num_inference_steps - 1 - progress_id) / max(num_inference_steps - 1, 1)
|
|
||||||
if progress > controlnet_input.start or progress < controlnet_input.end:
|
|
||||||
continue
|
|
||||||
res_stack_ = self.process_single_controlnet(controlnet_input, conditioning, **kwargs)
|
|
||||||
if res_stack is None:
|
|
||||||
res_stack = res_stack_
|
|
||||||
else:
|
|
||||||
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
|
||||||
return res_stack
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImagePipeline(BasePipeline):
|
class QwenImagePipeline(BasePipeline):
|
||||||
|
|
||||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
@@ -65,16 +30,14 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
self.text_encoder: QwenImageTextEncoder = None
|
self.text_encoder: QwenImageTextEncoder = None
|
||||||
self.dit: QwenImageDiT = None
|
self.dit: QwenImageDiT = None
|
||||||
self.vae: QwenImageVAE = None
|
self.vae: QwenImageVAE = None
|
||||||
self.controlnet: QwenImageMultiControlNet = None
|
|
||||||
self.tokenizer: Qwen2Tokenizer = None
|
self.tokenizer: Qwen2Tokenizer = None
|
||||||
self.unit_runner = PipelineUnitRunner()
|
self.unit_runner = PipelineUnitRunner()
|
||||||
self.in_iteration_models = ("dit", "controlnet")
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
QwenImageUnit_ShapeChecker(),
|
QwenImageUnit_ShapeChecker(),
|
||||||
QwenImageUnit_NoiseInitializer(),
|
QwenImageUnit_NoiseInitializer(),
|
||||||
QwenImageUnit_InputImageEmbedder(),
|
QwenImageUnit_InputImageEmbedder(),
|
||||||
QwenImageUnit_PromptEmbedder(),
|
QwenImageUnit_PromptEmbedder(),
|
||||||
QwenImageUnit_ControlNet(),
|
|
||||||
]
|
]
|
||||||
self.model_fn = model_fn_qwen_image
|
self.model_fn = model_fn_qwen_image
|
||||||
|
|
||||||
@@ -202,7 +165,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
|
||||||
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
pipe.dit = model_manager.fetch_model("qwen_image_dit")
|
||||||
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
pipe.vae = model_manager.fetch_model("qwen_image_vae")
|
||||||
pipe.controlnet = QwenImageMultiControlNet(model_manager.fetch_model("qwen_image_controlnet", index="all"))
|
|
||||||
if tokenizer_config is not None and pipe.text_encoder is not None:
|
if tokenizer_config is not None and pipe.text_encoder is not None:
|
||||||
tokenizer_config.download_if_necessary()
|
tokenizer_config.download_if_necessary()
|
||||||
from transformers import Qwen2Tokenizer
|
from transformers import Qwen2Tokenizer
|
||||||
@@ -228,8 +190,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
# ControlNet
|
|
||||||
controlnet_inputs: list[ControlNetInput] = None,
|
|
||||||
# Tile
|
# Tile
|
||||||
tiled: bool = False,
|
tiled: bool = False,
|
||||||
tile_size: int = 128,
|
tile_size: int = 128,
|
||||||
@@ -252,8 +212,6 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"num_inference_steps": num_inference_steps,
|
|
||||||
"controlnet_inputs": controlnet_inputs,
|
|
||||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||||
}
|
}
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
@@ -365,82 +323,18 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageUnit_ControlNet(PipelineUnit):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
|
|
||||||
onload_model_names=("vae",)
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_controlnet_mask_on_latents(self, pipe, latents, mask):
|
|
||||||
mask = (pipe.preprocess_image(mask) + 1) / 2
|
|
||||||
mask = mask.mean(dim=1, keepdim=True)
|
|
||||||
mask = 1 - torch.nn.functional.interpolate(mask, size=latents.shape[-2:])
|
|
||||||
latents = torch.concat([latents, mask], dim=1)
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def apply_controlnet_mask_on_image(self, pipe, image, mask):
|
|
||||||
mask = mask.resize(image.size)
|
|
||||||
mask = pipe.preprocess_image(mask).mean(dim=[0, 1]).cpu()
|
|
||||||
image = np.array(image)
|
|
||||||
image[mask > 0] = 0
|
|
||||||
image = Image.fromarray(image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, controlnet_inputs: list[ControlNetInput], tiled, tile_size, tile_stride):
|
|
||||||
if controlnet_inputs is None:
|
|
||||||
return {}
|
|
||||||
pipe.load_models_to_device(self.onload_model_names)
|
|
||||||
conditionings = []
|
|
||||||
for controlnet_input in controlnet_inputs:
|
|
||||||
image = controlnet_input.image
|
|
||||||
if controlnet_input.inpaint_mask is not None:
|
|
||||||
image = self.apply_controlnet_mask_on_image(pipe, image, controlnet_input.inpaint_mask)
|
|
||||||
|
|
||||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
|
||||||
image = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
|
||||||
|
|
||||||
if controlnet_input.inpaint_mask is not None:
|
|
||||||
image = self.apply_controlnet_mask_on_latents(pipe, image, controlnet_input.inpaint_mask)
|
|
||||||
conditionings.append(image)
|
|
||||||
return {"controlnet_conditionings": conditionings}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def model_fn_qwen_image(
|
def model_fn_qwen_image(
|
||||||
dit: QwenImageDiT = None,
|
dit: QwenImageDiT = None,
|
||||||
controlnet: QwenImageMultiControlNet = None,
|
|
||||||
latents=None,
|
latents=None,
|
||||||
timestep=None,
|
timestep=None,
|
||||||
prompt_emb=None,
|
prompt_emb=None,
|
||||||
prompt_emb_mask=None,
|
prompt_emb_mask=None,
|
||||||
height=None,
|
height=None,
|
||||||
width=None,
|
width=None,
|
||||||
controlnet_inputs=None,
|
|
||||||
controlnet_conditionings=None,
|
|
||||||
progress_id=0,
|
|
||||||
num_inference_steps=1,
|
|
||||||
use_gradient_checkpointing=False,
|
use_gradient_checkpointing=False,
|
||||||
use_gradient_checkpointing_offload=False,
|
use_gradient_checkpointing_offload=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# ControlNet
|
|
||||||
if controlnet_conditionings is not None:
|
|
||||||
controlnet_extra_kwargs = {
|
|
||||||
"latents": latents,
|
|
||||||
"timestep": timestep,
|
|
||||||
"prompt_emb": prompt_emb,
|
|
||||||
"prompt_emb_mask": prompt_emb_mask,
|
|
||||||
"height": height,
|
|
||||||
"width": width,
|
|
||||||
"use_gradient_checkpointing": use_gradient_checkpointing,
|
|
||||||
"use_gradient_checkpointing_offload": use_gradient_checkpointing_offload,
|
|
||||||
}
|
|
||||||
res_stack = controlnet(
|
|
||||||
controlnet_conditionings, controlnet_inputs, progress_id, num_inference_steps,
|
|
||||||
**controlnet_extra_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||||
timestep = timestep / 1000
|
timestep = timestep / 1000
|
||||||
@@ -452,7 +346,7 @@ def model_fn_qwen_image(
|
|||||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||||
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||||
|
|
||||||
for block_id, block in enumerate(dit.transformer_blocks):
|
for block in dit.transformer_blocks:
|
||||||
text, image = gradient_checkpoint_forward(
|
text, image = gradient_checkpoint_forward(
|
||||||
block,
|
block,
|
||||||
use_gradient_checkpointing,
|
use_gradient_checkpointing,
|
||||||
@@ -462,8 +356,6 @@ def model_fn_qwen_image(
|
|||||||
temb=conditioning,
|
temb=conditioning,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
)
|
)
|
||||||
if controlnet_inputs is not None:
|
|
||||||
image = image + res_stack[block_id]
|
|
||||||
|
|
||||||
image = dit.norm_out(image, conditioning)
|
image = dit.norm_out(image, conditioning)
|
||||||
image = dit.proj_out(image)
|
image = dit.proj_out(image)
|
||||||
|
|||||||
@@ -110,8 +110,47 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
self.lora_A_weights = []
|
self.lora_A_weights = []
|
||||||
self.lora_B_weights = []
|
self.lora_B_weights = []
|
||||||
self.lora_merger = None
|
self.lora_merger = None
|
||||||
|
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
|
||||||
|
|
||||||
|
def fp8_linear(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
device = input.device
|
||||||
|
origin_dtype = input.dtype
|
||||||
|
origin_shape = input.shape
|
||||||
|
input = input.reshape(-1, origin_shape[-1])
|
||||||
|
|
||||||
|
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
|
||||||
|
fp8_max = 448.0
|
||||||
|
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
|
||||||
|
# To avoid overflow and ensure numerical compatibility during FP8 computation,
|
||||||
|
# we scale down the input by 2.0 in advance.
|
||||||
|
# This scaling will be compensated later during the final result scaling.
|
||||||
|
if self.computation_dtype == torch.float8_e4m3fnuz:
|
||||||
|
fp8_max = fp8_max / 2.0
|
||||||
|
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
|
||||||
|
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
|
||||||
|
input = input / (scale_a + 1e-8)
|
||||||
|
input = input.to(self.computation_dtype)
|
||||||
|
weight = weight.to(self.computation_dtype)
|
||||||
|
|
||||||
|
result = torch._scaled_mm(
|
||||||
|
input,
|
||||||
|
weight.T,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b.T,
|
||||||
|
bias=bias,
|
||||||
|
out_dtype=origin_dtype,
|
||||||
|
)
|
||||||
|
new_shape = origin_shape[:-1] + result.shape[-1:]
|
||||||
|
result = result.reshape(new_shape)
|
||||||
|
return result
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
def forward(self, x, *args, **kwargs):
|
||||||
|
# VRAM management
|
||||||
if self.state == 2:
|
if self.state == 2:
|
||||||
weight, bias = self.weight, self.bias
|
weight, bias = self.weight, self.bias
|
||||||
else:
|
else:
|
||||||
@@ -123,8 +162,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
|
|||||||
else:
|
else:
|
||||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||||
out = torch.nn.functional.linear(x, weight, bias)
|
|
||||||
|
|
||||||
|
# Linear forward
|
||||||
|
if self.enable_fp8:
|
||||||
|
out = self.fp8_linear(x, weight, bias)
|
||||||
|
else:
|
||||||
|
out = torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
# LoRA
|
||||||
if len(self.lora_A_weights) == 0:
|
if len(self.lora_A_weights) == 0:
|
||||||
# No LoRA
|
# No LoRA
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
accelerate launch examples/qwen_image/model_training/train.py \
|
|
||||||
--dataset_base_path data/example_image_dataset \
|
|
||||||
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
|
|
||||||
--data_file_keys "image,controlnet_image" \
|
|
||||||
--max_pixels 1048576 \
|
|
||||||
--dataset_repeat 8000 \
|
|
||||||
--model_paths '[
|
|
||||||
[
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
|
||||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
|
||||||
],
|
|
||||||
"models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors",
|
|
||||||
"models/controlnet.safetensors"
|
|
||||||
]' \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_epochs 2 \
|
|
||||||
--remove_prefix_in_ckpt "pipe.controlnet.models.0." \
|
|
||||||
--output_path "./models/train/Qwen-Image-ControlNet_full" \
|
|
||||||
--trainable_models "controlnet" \
|
|
||||||
--extra_inputs "controlnet_image" \
|
|
||||||
--use_gradient_checkpointing
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
# This script is for initializing a Qwen-Image-ControlNet
|
|
||||||
from diffsynth import load_state_dict, hash_state_dict_keys
|
|
||||||
from diffsynth.pipelines.qwen_image import QwenImageControlNet
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
|
|
||||||
state_dict_dit = {}
|
|
||||||
for i in range(1, 10):
|
|
||||||
state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
|
|
||||||
|
|
||||||
controlnet = QwenImageControlNet().to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
state_dict_controlnet = controlnet.state_dict()
|
|
||||||
|
|
||||||
state_dict_init = {}
|
|
||||||
for k in state_dict_controlnet:
|
|
||||||
if k in state_dict_dit:
|
|
||||||
if state_dict_dit[k].shape == state_dict_controlnet[k].shape:
|
|
||||||
state_dict_init[k] = state_dict_dit[k]
|
|
||||||
elif k == "img_in.weight":
|
|
||||||
state_dict_init[k] = torch.concat(
|
|
||||||
[
|
|
||||||
state_dict_dit[k],
|
|
||||||
state_dict_dit[k],
|
|
||||||
],
|
|
||||||
dim=-1
|
|
||||||
)
|
|
||||||
elif k == "alpha":
|
|
||||||
state_dict_init[k] = torch.zeros_like(state_dict_controlnet[k])
|
|
||||||
controlnet.load_state_dict(state_dict_init)
|
|
||||||
|
|
||||||
print(hash_state_dict_keys(state_dict_controlnet))
|
|
||||||
save_file(state_dict_controlnet, "models/controlnet.safetensors")
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import torch, os, json
|
import torch, os, json
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
|
||||||
|
from diffsynth.models.lora import QwenImageLoRAConverter
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@@ -72,14 +73,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Extra inputs
|
# Extra inputs
|
||||||
controlnet_input = {}
|
|
||||||
for extra_input in self.extra_inputs:
|
for extra_input in self.extra_inputs:
|
||||||
if extra_input.startswith("controlnet_"):
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
|
|
||||||
else:
|
|
||||||
inputs_shared[extra_input] = data[extra_input]
|
|
||||||
if len(controlnet_input) > 0:
|
|
||||||
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
|
|
||||||
|
|
||||||
# Pipeline units will automatically process the input parameters.
|
# Pipeline units will automatically process the input parameters.
|
||||||
for unit in self.pipe.units:
|
for unit in self.pipe.units:
|
||||||
@@ -114,6 +109,7 @@ if __name__ == "__main__":
|
|||||||
model_logger = ModelLogger(
|
model_logger = ModelLogger(
|
||||||
args.output_path,
|
args.output_path,
|
||||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
|
||||||
)
|
)
|
||||||
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
|
||||||
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
|
||||||
|
|||||||
Reference in New Issue
Block a user