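"""Qwen-Image DiT building blocks.

Contains the flash-attention helper (with an optional FP8 path), the 3D rotary embedding
module (QwenEmbedRope), the dual-stream attention/transformer blocks, and the full
QwenImageDiT model together with its state-dict converter.

Illustrative usage (shapes inferred from the layer definitions below; not a tested snippet):

    model = QwenImageDiT(num_layers=60)
    latents = torch.randn(1, 16, 128, 128)             # (B, 16, height // 8, width // 8)
    prompt_emb = torch.randn(1, 77, 3584)               # text encoder hidden states
    prompt_emb_mask = torch.ones(1, 77, dtype=torch.long)
    timestep = torch.tensor([500.0])
    out = model(latents, timestep, prompt_emb, prompt_emb_mask, height=1024, width=1024)
"""
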
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange

from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


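# qwen_image_flash_attention dispatches between FlashAttention-3 and PyTorch SDPA.
# Inputs q/k/v are laid out as (batch, heads, seq, head_dim); the output is flattened back
# to (batch, seq, heads * head_dim). The FP8 path rescales q, k, v by their standard
# deviations before casting to float8_e4m3fn, folds the q/k scales back in through
# softmax_scale, and multiplies the output by v's scale afterwards, so the result
# approximates the full-precision attention. A rough sketch of the equivalent
# full-precision computation (illustrative only):
#
#   x = torch.nn.functional.scaled_dot_product_attention(q, k, v)   # (b, n, s, d)
#   x = rearrange(x, "b n s d -> b s (n d)")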
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask=None, enable_fp8_attention: bool = False):
    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
        if not enable_fp8_attention:
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(q, k, v)
            if isinstance(x, tuple):
                x = x[0]
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
        else:
            origin_dtype = q.dtype
            q_std, k_std, v_std = q.std(), k.std(), v.std()
            q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
            x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
            if isinstance(x, tuple):
                x = x[0]
            x = x.to(origin_dtype) * v_std
            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    else:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x


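# Sigmoid-based GELU approximation: GELU(x) ≈ x * sigmoid(1.702 * x), applied after a
# linear projection.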
class ApproximateGELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


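# Applies rotary position embeddings in complex form: the last dimension of x is viewed as
# interleaved (real, imaginary) pairs, multiplied element-wise by the complex frequencies in
# freqs_cis, and converted back. For a query of shape (batch, heads, seq, head_dim),
# freqs_cis is expected to broadcast against (batch, heads, seq, head_dim // 2).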
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
):
    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
    return x_out.type_as(x)


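# 3D rotary embedding for the joint text/image sequence. Positions are factorised along
# (frame, height, width), with per-axis channel budgets given by axes_dim. When scale_rope
# is enabled, spatial indices are centred around zero (negative frequencies for the first
# half, positive for the second), which keeps the image roughly centred in RoPE space.
# Frequency tables are precomputed for 4096 positions and expanded lazily on demand.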
class QwenEmbedRope(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.rope_cache = {}
        self.scale_rope = scale_rope

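    # rope_params builds the complex rotation table for one axis: for position p and channel
    # pair i, the entry is exp(j * p * theta^(-2i / dim)), i.e. the standard RoPE frequencies
    # expressed via torch.polar.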
    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: 1D tensor of token position indices, e.g. torch.arange(4096).
            dim: number of channels assigned to this axis (must be even).
            theta: RoPE base frequency.
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            index,
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

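    # The precomputed tables cover 4096 positions. If a request needs more (large images or
    # long prompts), the tables are rebuilt to the required length, rounded up to a multiple
    # of 512.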
    def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
        if isinstance(video_fhw, list):
            video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
        _, height, width = video_fhw
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2)
        else:
            max_vid_index = max(height, width)
        required_len = max_vid_index + max(txt_seq_lens)
        cur_max_len = self.pos_freqs.shape[0]
        if required_len <= cur_max_len:
            return

        new_max_len = math.ceil(required_len / 512) * 512
        pos_index = torch.arange(new_max_len)
        neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        return

    def forward(self, video_fhw, txt_seq_lens, device):
        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            if rope_key not in self.rope_cache:
                seq_lens = frame * height * width
                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                if self.scale_rope:
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2):], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2):], freqs_pos[2][: width // 2]], dim=0)
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
                self.rope_cache[rope_key] = freqs.clone().contiguous()
            vid_freqs.append(self.rope_cache[rope_key])

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

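    # forward_sampling behaves like forward, but when an image after the first has a
    # different spatial size, its rotary grid is obtained by resampling the cached grid of
    # image 0 with torch.linspace indices over height and width, and only the frame-axis
    # frequencies are recomputed; the apparent intent is to keep additional conditioning
    # images positionally aligned with the first image.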
    def forward_sampling(self, video_fhw, txt_seq_lens, device):
        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"
            if idx > 0 and f"{0}_{height}_{width}" not in self.rope_cache:
                frame_0, height_0, width_0 = video_fhw[0]

                rope_key_0 = f"0_{height_0}_{width_0}"
                spatial_freqs_0 = self.rope_cache[rope_key_0].reshape(frame_0, height_0, width_0, -1)
                h_indices = torch.linspace(0, height_0 - 1, height).long()
                w_indices = torch.linspace(0, width_0 - 1, width).long()
                h_grid, w_grid = torch.meshgrid(h_indices, w_indices, indexing='ij')
                sampled_rope = spatial_freqs_0[:, h_grid, w_grid, :]

                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                sampled_rope[:, :, :, :freqs_frame.shape[-1]] = freqs_frame

                seq_lens = frame * height * width
                self.rope_cache[rope_key] = sampled_rope.reshape(seq_lens, -1).clone()
            if rope_key not in self.rope_cache:
                seq_lens = frame * height * width
                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                if self.scale_rope:
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2):], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2):], freqs_pos[2][: width // 2]], dim=0)
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
                self.rope_cache[rope_key] = freqs.clone()
            vid_freqs.append(self.rope_cache[rope_key].contiguous())

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs


class QwenFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Default to the input width when dim_out is not given.
        dim_out = dim_out if dim_out is not None else dim
        inner_dim = int(dim * 4)
        self.net = nn.ModuleList([])
        self.net.append(ApproximateGELU(dim, inner_dim))
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


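# Dual-stream joint attention: the image and text streams are projected with separate q/k/v
# layers and per-head RMS-normalised, rotary embeddings are applied to each stream, and the
# two are concatenated (text first, then image) into a single attention call. The joint
# output is split back at the text sequence length and projected per stream.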
class QwenDoubleStreamAttention(nn.Module):
    def __init__(
        self,
        dim_a,
        dim_b,
        num_heads,
        head_dim,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = nn.Linear(dim_a, dim_a)
        self.to_k = nn.Linear(dim_a, dim_a)
        self.to_v = nn.Linear(dim_a, dim_a)
        self.norm_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_k = RMSNorm(head_dim, eps=1e-6)

        self.add_q_proj = nn.Linear(dim_b, dim_b)
        self.add_k_proj = nn.Linear(dim_b, dim_b)
        self.add_v_proj = nn.Linear(dim_b, dim_b)
        self.norm_added_q = RMSNorm(head_dim, eps=1e-6)
        self.norm_added_k = RMSNorm(head_dim, eps=1e-6)

        self.to_out = torch.nn.Sequential(nn.Linear(dim_a, dim_a))
        self.to_add_out = nn.Linear(dim_b, dim_b)

    def forward(
        self,
        image: torch.FloatTensor,
        text: torch.FloatTensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        enable_fp8_attention: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
        txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
        seq_txt = txt_q.shape[1]

        img_q = rearrange(img_q, 'b s (h d) -> b h s d', h=self.num_heads)
        img_k = rearrange(img_k, 'b s (h d) -> b h s d', h=self.num_heads)
        img_v = rearrange(img_v, 'b s (h d) -> b h s d', h=self.num_heads)

        txt_q = rearrange(txt_q, 'b s (h d) -> b h s d', h=self.num_heads)
        txt_k = rearrange(txt_k, 'b s (h d) -> b h s d', h=self.num_heads)
        txt_v = rearrange(txt_v, 'b s (h d) -> b h s d', h=self.num_heads)

        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

        joint_q = torch.cat([txt_q, img_q], dim=2)
        joint_k = torch.cat([txt_k, img_k], dim=2)
        joint_v = torch.cat([txt_v, img_v], dim=2)

        joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)

        txt_attn_output = joint_attn_out[:, :seq_txt, :]
        img_attn_output = joint_attn_out[:, seq_txt:, :]

        img_attn_output = self.to_out(img_attn_output)
        txt_attn_output = self.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output


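# MM-DiT-style dual-stream block. Each stream has its own AdaLN modulation: a SiLU + Linear
# head maps the timestep embedding to 6 * dim values, split into (shift, scale, gate) for
# the attention sub-layer and (shift, scale, gate) for the MLP sub-layer; the gates scale
# the residual updates.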
class QwenImageTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = QwenDoubleStreamAttention(
            dim_a=dim,
            dim_b=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim)

        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim)

    def _modulate(self, x, mod_params):
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        enable_fp8_attention=False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each

        img_normed = self.img_norm1(image)
        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)

        txt_normed = self.txt_norm1(text)
        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)

        img_attn_out, txt_attn_out = self.attn(
            image=img_modulated,
            text=txt_modulated,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            enable_fp8_attention=enable_fp8_attention,
        )

        image = image + img_gate * img_attn_out
        text = text + txt_gate * txt_attn_out

        img_normed_2 = self.img_norm2(image)
        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)

        txt_normed_2 = self.txt_norm2(text)
        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)

        img_mlp_out = self.img_mlp(img_modulated_2)
        txt_mlp_out = self.txt_mlp(txt_modulated_2)

        image = image + img_gate_2 * img_mlp_out
        text = text + txt_gate_2 * txt_mlp_out

        return text, image


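# Full Qwen-Image DiT: 16-channel latents are patchified 2x2 into 64-dim tokens and
# projected to the 3072-dim image stream; text embeddings (3584-dim) are RMS-normalised and
# projected to the same width. A stack of num_layers dual-stream blocks (24 heads x 128
# head_dim) is followed by an AdaLayerNorm and a projection back to 64-dim patches.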
class QwenImageDiT(torch.nn.Module):
    def __init__(
        self,
        num_layers: int = 60,
    ):
        super().__init__()

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)

        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.txt_norm = RMSNorm(3584, eps=1e-6)

        self.img_in = nn.Linear(64, 3072)
        self.txt_in = nn.Linear(3584, 3072)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = nn.Linear(3072, 64)

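    # Builds the inputs for entity-controlled generation: per-entity prompt embeddings are
    # concatenated with the global prompt, rotary embeddings are assembled for the combined
    # text sequence, and a block attention mask is constructed so that each entity prompt
    # only attends to (and is attended by) the image tokens inside its spatial mask, while
    # different prompts are prevented from attending to each other.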
    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, entity_prompt_emb_mask, entity_masks, height, width, image, img_shapes):
        # prompt_emb
        all_prompt_emb = entity_prompt_emb + [prompt_emb]
        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

        # image_rotary_emb
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
        entity_seq_lens = [emb_mask.sum(dim=1).tolist() for emb_mask in entity_prompt_emb_mask]
        entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

        # attention_mask
        repeat_dim = latents.shape[1]
        max_masks = entity_masks.shape[1]
        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
        global_mask = torch.ones_like(entity_masks[0]).to(device=latents.device, dtype=latents.dtype)
        entity_masks = entity_masks + [global_mask]

        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
        total_seq_len = sum(seq_lens) + image.shape[1]
        patched_masks = []
        for i in range(N):
            patched_mask = rearrange(entity_masks[i], "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
            patched_masks.append(patched_mask)
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        # prompt-image attention mask
        image_start = sum(seq_lens)
        image_end = total_seq_len
        cumsum = [0]
        single_image_seq = image_end - image_start
        for length in seq_lens:
            cumsum.append(cumsum[-1] + length)
        for i in range(N):
            prompt_start = cumsum[i]
            prompt_end = cumsum[i+1]
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
            # repeat the image mask to match the single image sequence length
            repeat_time = single_image_seq // image_mask.shape[-1]
            image_mask = image_mask.repeat(1, 1, repeat_time)
            # prompt tokens attend only to image tokens inside this entity's mask
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image tokens attend back to the corresponding prompt tokens
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt attention mask: different prompts must not attend to each other
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                start_i, end_i = cumsum[i], cumsum[i+1]
                start_j, end_j = cumsum[j], cumsum[j+1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)

        return all_prompt_emb, image_rotary_emb, attention_mask

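    # Shapes, as implied by the layers above (illustrative, assuming `height` x `width`
    # pixels and an 8x VAE downsample):
    #   latents:      (B, 16, height // 8, width // 8)
    #   image tokens: (B, (height // 16) * (width // 16), 64) after 2x2 patchify, 3072 after img_in
    #   prompt_emb:   (B, txt_seq_len, 3584), projected to 3072 by txt_in
    # The output is un-patchified back to the latent shape.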
    def forward(
        self,
        latents=None,
        timestep=None,
        prompt_emb=None,
        prompt_emb_mask=None,
        height=None,
        width=None,
    ):
        img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
        txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()

        image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
        image = self.img_in(image)
        text = self.txt_in(self.txt_norm(prompt_emb))

        conditioning = self.time_text_embed(timestep, image.dtype)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )

        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)

        # Un-patchify back to the latent layout and return it.
        latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
        return latents

    @staticmethod
    def state_dict_converter():
        return QwenImageDiTStateDictConverter()


class QwenImageDiTStateDictConverter:
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        return state_dict