Files
DiffSynth-Studio/diffsynth/models/qwen_image_accelerate_adapter.py
2025-08-08 16:51:42 +08:00

64 lines
1.9 KiB
Python

from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
from einops import rearrange
import torch
class QwenImageAccelerateAdapter(torch.nn.Module):
    """Adapter that mixes patchified latents into image tokens.

    The latents are split into 2x2 patches, projected to the hidden width,
    added to the incoming image tokens, run through a stack of
    ``QwenImageTransformerBlock`` layers conditioned on the timestep
    embedding, and finally projected back to the patch width with a linear
    skip connection from the patchified latents.
    """

    def __init__(
        self,
        num_layers: int = 1,
    ):
        """Create the projections, timestep embedding and transformer stack.

        Args:
            num_layers: How many ``QwenImageTransformerBlock`` layers to stack.
        """
        super().__init__()

        # Width of the token stream inside the adapter.
        hidden_dim = 3072
        # Width of one patchified latent token, i.e. C * P * Q with P = Q = 2
        # as produced by the rearrange in ``forward``.
        patch_dim = 64

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.proj_latents_in = torch.nn.Linear(patch_dim, hidden_dim)
        self.time_text_embed = TimestepEmbeddings(
            256,
            hidden_dim,
            diffusers_compatible_format=True,
            scale=1000,
            align_dtype_to_timestep=True,
        )
        blocks = [
            QwenImageTransformerBlock(
                dim=hidden_dim,
                num_attention_heads=24,
                attention_head_dim=128,
            )
            for _ in range(num_layers)
        ]
        self.transformer_blocks = torch.nn.ModuleList(blocks)
        self.norm_out = AdaLayerNorm(hidden_dim, single=True)
        self.proj_out = torch.nn.Linear(hidden_dim, patch_dim)
        self.proj_latents_out = torch.nn.Linear(patch_dim, patch_dim)

    def forward(
        self,
        latents=None,
        image=None,
        text=None,
        image_rotary_emb=None,
        timestep=None,
    ):
        """Run the adapter and return the refined patch tokens.

        Args:
            latents: 4-D tensor laid out as ``B C H W`` with even ``H`` and
                ``W``; after 2x2 patchification its last dimension
                (``C * 4``) must equal 64 to fit ``proj_latents_in``.
            image: Image token tensor that the projected latents are added
                to (presumably ``(B, H/2 * W/2, 3072)`` -- TODO confirm
                against the caller).
            text: Text tokens forwarded to, and updated by, every
                transformer block.
            image_rotary_emb: Rotary embedding passed through unchanged to
                every transformer block.
            timestep: Timestep input for ``time_text_embed``.

        Returns:
            The output of ``proj_out`` on the normalised tokens plus
            ``proj_latents_out`` applied to the patchified latents.
        """
        # Fold every 2x2 spatial patch into the channel axis: one token per patch.
        patches = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)

        hidden = image + self.proj_latents_in(patches)
        temb = self.time_text_embed(timestep, hidden.dtype)

        for layer in self.transformer_blocks:
            # Each block returns the updated (text, image) pair.
            text, hidden = layer(
                image=hidden,
                text=text,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

        hidden = self.norm_out(hidden, temb)
        projected = self.proj_out(hidden)
        # Linear skip connection from the patchified latents.
        return projected + self.proj_latents_out(patches)

    @staticmethod
    def state_dict_converter():
        """Return the state-dict converter associated with this model."""
        converter = QwenImageAccelerateAdapterStateDictConverter()
        return converter
class QwenImageAccelerateAdapterStateDictConverter:
    """Pass-through state-dict converter for ``QwenImageAccelerateAdapter``."""

    def __init__(self):
        """Create the converter; it holds no state."""

    def from_civitai(self, state_dict):
        """Return ``state_dict`` as-is; no key renaming or copying is done."""
        converted = state_dict
        return converted