mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
160 lines
5.0 KiB
Python
160 lines
5.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .qwen_image_dit import QwenEmbedRope, QwenImageTransformerBlock
|
|
from ..vram_management import gradient_checkpoint_forward
|
|
from einops import rearrange
|
|
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
|
|
|
|
|
class QwenImageControlNet(torch.nn.Module):
    """ControlNet branch that produces one residual tensor per DiT layer.

    A stack of ``num_controlnet_layers`` ``QwenImageTransformerBlock`` modules
    processes the patchified latents concatenated with the patchified control
    signal. The image stream after every block is kept, and each of the
    ``num_layers`` output projections is applied to the block output selected
    by ``align_map``.

    Args:
        num_layers: Number of output projections (one per returned tensor).
        num_controlnet_layers: Number of transformer blocks in this branch.

    Raises:
        ValueError: If ``num_layers`` is positive while
            ``num_controlnet_layers`` is not.
    """

    def __init__(
        self,
        num_layers: int = 60,
        num_controlnet_layers: int = 6,
    ):
        super().__init__()

        # Validate the layer counts before allocating any weights.
        align_map = self._build_align_map(num_layers, num_controlnet_layers)

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16,56,56], scale_rope=True)

        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.txt_norm = RMSNorm(3584, eps=1e-6)

        # 64 * 2: patchified latents and patchified control signal are
        # concatenated on the feature axis before this projection.
        self.img_in = nn.Linear(64 * 2, 3072)
        self.txt_in = nn.Linear(3584, 3072)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_controlnet_layers)
            ]
        )
        self.proj_out = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for i in range(num_layers)])
        self.num_layers = num_layers
        self.num_controlnet_layers = num_controlnet_layers
        self.align_map = align_map

    @staticmethod
    def _build_align_map(num_layers, num_controlnet_layers) -> dict:
        """Map each output index to the controlnet block that feeds it.

        Uses ``i * num_controlnet_layers // num_layers``. When ``num_layers``
        is a multiple of ``num_controlnet_layers`` this equals
        ``i // (num_layers // num_controlnet_layers)``; for any other
        combination it still yields indices in
        ``[0, num_controlnet_layers - 1]``, so the lookup into the per-block
        outputs can never run past the end.

        Raises:
            ValueError: If ``num_layers`` is positive while
                ``num_controlnet_layers`` is not.
        """
        if num_layers > 0 and num_controlnet_layers <= 0:
            raise ValueError(
                "num_controlnet_layers must be positive when num_layers > 0, "
                f"got num_layers={num_layers}, num_controlnet_layers={num_controlnet_layers}"
            )
        return {i: i * num_controlnet_layers // num_layers for i in range(num_layers)}

    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
        controlnet_conditioning=None,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        **kwargs,
    ):
        """Compute the per-layer controlnet outputs.

        Args:
            latents: 4-D tensor laid out as ``B C (H*2) (W*2)``.
            timestep: Passed to the timestep embedding.
            prompt_emb: Text features; normalised and projected by
                ``txt_norm`` / ``txt_in``.
            prompt_emb_mask: Mask whose sum over dim 1 gives the text
                sequence lengths handed to the rotary embedding.
            height: Optional; the patch grid height is ``height // 16``.
                When None, it is taken as ``latents.shape[2] // 2``.
            width: Optional; the patch grid width is ``width // 16``.
                When None, it is taken as ``latents.shape[3] // 2``.
            controlnet_conditioning: Control signal, patchified with the same
                grid as ``latents`` (so it must share their spatial size).
            use_gradient_checkpointing: Forwarded to
                ``gradient_checkpoint_forward``.
            use_gradient_checkpointing_offload: Forwarded to
                ``gradient_checkpoint_forward``.
            **kwargs: Accepted and ignored.

        Returns:
            A list of ``num_layers`` tensors; entry ``i`` is ``proj_out[i]``
            applied to the image stream after block ``align_map[i]``.
        """
        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        # With 2x2 patches the only grid that fits the latents is half their
        # spatial size, so fall back to that when height/width are omitted.
        grid_h = latents.shape[2] // 2 if height is None else height // 16
        grid_w = latents.shape[3] // 2 if width is None else width // 16

        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=grid_h, W=grid_w, P=2, Q=2)
        controlnet_conditioning = rearrange(controlnet_conditioning, "B C (H P) (W Q) -> B (H W) (C P Q)", H=grid_h, W=grid_w, P=2, Q=2)
        image = torch.concat([image, controlnet_conditioning], dim=-1)

        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))

        conditioning = self.time_text_embed(timestep, image.dtype)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        # Keep the image stream after every block; the text stream is only
        # carried forward to the next block.
        outputs = []
        for block in self.transformer_blocks:
            text, image = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )
            outputs.append(image)

        outputs_aligned = [
            proj(outputs[self.align_map[i]]) for i, proj in enumerate(self.proj_out)
        ]
        return outputs_aligned

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter associated with this model."""
        return QwenImageControlNetStateDictConverter()
|
|
|
|
|
|
|
|
class QwenImageControlNetStateDictConverter:
    """Pass-through state-dict converter for ``QwenImageControlNet``."""

    def __init__(self):
        """Create the converter; it holds no state."""

    def from_civitai(self, state_dict):
        """Return ``state_dict`` as-is (the same object, not a copy)."""
        converted = state_dict
        return converted
|
|
|
|
|
|
class BlockWiseControlBlock(torch.nn.Module):
    """Fuse two feature tensors through a [linear, gelu, linear] stack.

    Both inputs are RMS-normalised by their own norm layer, summed, and then
    passed through ``input_proj`` -> ``GELU`` -> ``output_proj``.
    """

    def __init__(self, dim: int = 3072):
        super().__init__()
        # Construction order is kept so parameter registration and seeded
        # initialisation stay the same.
        norm_eps = 1e-6
        self.x_rms = RMSNorm(dim, eps=norm_eps)
        self.y_rms = RMSNorm(dim, eps=norm_eps)
        self.input_proj = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, y):
        """Return ``output_proj(act(input_proj(x_rms(x) + y_rms(y))))``."""
        fused = self.x_rms(x) + self.y_rms(y)
        hidden = self.act(self.input_proj(fused))
        return self.output_proj(hidden)

    def init_weights(self):
        """Zero the weight and bias of ``output_proj``."""
        for param in (self.output_proj.weight, self.output_proj.bias):
            nn.init.zeros_(param)
|
|
|
|
|
|
class QwenImageBlockWiseControlNet(torch.nn.Module):
    """Block-wise ControlNet: one ``BlockWiseControlBlock`` per layer.

    Args:
        num_layers: Number of control blocks to create.
        in_dim: Input feature size of the conditioning projection ``img_in``.
        dim: Feature size used by ``img_in``'s output and every control block.
    """

    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        dim: int = 3072,
    ):
        super().__init__()
        self.img_in = nn.Linear(in_dim, dim)
        per_layer_blocks = [BlockWiseControlBlock(dim) for _ in range(num_layers)]
        self.controlnet_blocks = nn.ModuleList(per_layer_blocks)

    def init_weight(self):
        """Zero ``img_in`` and call ``init_weights`` on every control block."""
        for param in (self.img_in.weight, self.img_in.bias):
            nn.init.zeros_(param)
        for control_block in self.controlnet_blocks:
            control_block.init_weights()

    def process_controlnet_conditioning(self, controlnet_conditioning):
        """Project the conditioning tensor through ``img_in``."""
        projected = self.img_in(controlnet_conditioning)
        return projected

    def blockwise_forward(self, img, controlnet_conditioning, block_id):
        """Apply the control block at index ``block_id`` to the two inputs."""
        selected_block = self.controlnet_blocks[block_id]
        return selected_block(img, controlnet_conditioning)

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter associated with this model."""
        return QwenImageBlockWiseControlNetStateDictConverter()
|
|
|
|
|
|
class QwenImageBlockWiseControlNetStateDictConverter:
    """Pass-through state-dict converter for ``QwenImageBlockWiseControlNet``."""

    def __init__(self):
        """Create the converter; it holds no state."""

    def from_civitai(self, state_dict):
        """Return ``state_dict`` as-is (the same object, not a copy)."""
        converted = state_dict
        return converted
|